-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerate_importance_score_imagenet.py
125 lines (101 loc) · 4.67 KB
/
generate_importance_score_imagenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import torch
import numpy as np
import torchvision
from torchvision import datasets, transforms
import torch.nn as nn
import torch.optim as optim
import os, sys
import argparse
import pickle
from core.data import IndexDataset, ImageNetDataset, CustomImageNetDataset
import torch.nn.functional as F
# -------- Command-line interface --------
# `args` is consumed by the rest of this script, so its name is part of the
# module's interface and must stay unchanged.
arg_parser = argparse.ArgumentParser()

######################### Data Setting #########################
arg_parser.add_argument('--data-dir', type=str, default='../datasets/',
                        help='The dir path of the data.')
arg_parser.add_argument('--base-dir', type=str)
arg_parser.add_argument('--task-name', type=str)
arg_parser.add_argument('--data-score-path', type=str)

################### Load Pseudo Labels from DL models ###################
arg_parser.add_argument('--load_pseudo', action='store_true', default=False)
arg_parser.add_argument('--pseudo_train_label_path', type=str,
                        help='Path for the pseudo train labels')

args = arg_parser.parse_args()
def EL2N(td_log, data_importance, max_epoch):
    """Accumulate per-sample EL2N scores from logged training dynamics.

    For each logged batch, the EL2N score is the L2 norm of
    (softmax(logits) - one_hot(label)), added into
    ``data_importance['el2n']`` at the sample's dataset index.

    Args:
        td_log: iterable of dicts with keys 'output' (logits), 'idx'
            (dataset indices) and 'epoch'. (Schema inferred from usage
            here — confirm against the training-dynamics writer.)
        data_importance: dict holding the 'el2n' accumulator tensor.
        max_epoch: stop (without processing) at the first entry whose
            'epoch' equals this value.

    Reads module globals ``targets`` and ``num_classes`` defined later
    in this script.
    """
    l2_loss = torch.nn.MSELoss(reduction='none')

    def record_training_dynamics(entry):
        # Explicit dim=1: calling softmax with an implicit dim is
        # deprecated and ambiguous for 2-D logits.
        output = F.softmax(torch.tensor(entry['output'], dtype=torch.float32), dim=1)
        index = entry['idx'].type(torch.long)
        label = targets[index]
        # one_hot returns int64; cast to float32 so MSELoss sees
        # matching dtypes (a long/float mix raises on modern torch).
        label_onehot = torch.nn.functional.one_hot(
            label, num_classes=num_classes).type(torch.float32)
        el2n_score = torch.sqrt(l2_loss(label_onehot, output).sum(dim=1))
        data_importance['el2n'][index] += el2n_score

    for item in td_log:
        if item['epoch'] == max_epoch:
            return
        record_training_dynamics(item)
def training_dynamics_metrics(td_log, data_importance):
    """Accumulate forgetting, correctness, and margin statistics.

    For each logged batch:
      * 'forgetting'          — count of correct→incorrect transitions
                                between consecutive epochs,
      * 'last_correctness'    — 1/0 correctness at the latest epoch seen,
      * 'correctness'         — running sum of per-epoch correctness,
      * 'accumulated_margin'  — running sum of (target prob − highest
                                other-class prob).

    Args:
        td_log: iterable of dicts with keys 'output' (logits) and 'idx'
            (dataset indices). (Schema inferred from usage here —
            confirm against the training-dynamics writer.)
        data_importance: dict of accumulator tensors, indexed by sample.

    Reads the module global ``targets`` defined later in this script.
    """
    def record_training_dynamics(entry):
        # Explicit dim=1: implicit-dim softmax is deprecated/ambiguous.
        output = F.softmax(torch.tensor(entry['output'], dtype=torch.float32), dim=1)
        predicted = output.argmax(dim=1)
        index = entry['idx'].type(torch.long)
        label = targets[index]
        correctness = (predicted == label).type(torch.int)

        # A "forgetting" event: correct at the previous epoch, wrong now.
        data_importance['forgetting'][index] += torch.logical_and(
            data_importance['last_correctness'][index] == 1, correctness == 0)
        data_importance['last_correctness'][index] = correctness
        data_importance['correctness'][index] += data_importance['last_correctness'][index]

        # Margin = p(target) − max over the other classes; zeroing the
        # target column makes the row-max pick the runner-up class.
        batch_idx = range(output.shape[0])
        target_prob = output[batch_idx, label]
        output[batch_idx, label] = 0
        other_highest_prob = torch.max(output, dim=1)[0]
        margin = target_prob - other_highest_prob
        data_importance['accumulated_margin'][index] += margin

    for item in td_log:
        record_training_dynamics(item)
# ---------------------------------------------------------------------------
# Script body: build the (optionally pseudo-labelled) ImageNet train set,
# collect all labels into `targets`, initialise per-sample score
# accumulators, replay the logged training dynamics, and save the scores.
# ---------------------------------------------------------------------------

# Optionally replace ground-truth labels with pseudo labels from a DL model.
train_labels = None
if args.load_pseudo:
    print(f'Loading pseudo labels from {args.pseudo_train_label_path}')
    train_labels = torch.load(args.pseudo_train_label_path)

trainset = CustomImageNetDataset(path=os.path.join(args.data_dir, 'train'),
                                 pseudo_labels=train_labels)
trainset = IndexDataset(trainset)
# shuffle=False so batch order matches the stored dataset indices.
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=800, shuffle=False, pin_memory=True, num_workers=16)

# Sanity check: trainset[i] -> (idx, (image, label)).
print('First 100 train label:')
print([int(trainset[i][1][1]) for i in range(100)])

# Collect every label into one flat array (indexed by dataset position).
targets = []
print(f'Load label info from datasets...')
print(f'Total batch: {len(trainloader)}')
for batch_idx, (idx, (_, y)) in enumerate(trainloader):
    targets.extend(y.tolist())
    if batch_idx % 50 == 0:
        print(batch_idx)
print(len(targets))

# Per-sample score accumulators, all indexed by dataset position.
data_importance = {}
targets = torch.tensor(targets)
data_size = targets.shape[0]
num_classes = 1000  # ImageNet-1k

data_importance['targets'] = targets.type(torch.int32)
data_importance['el2n'] = torch.zeros(data_size).type(torch.float32)
data_importance['correctness'] = torch.zeros(data_size).type(torch.int32)
data_importance['forgetting'] = torch.zeros(data_size).type(torch.int32)
data_importance['last_correctness'] = torch.zeros(data_size).type(torch.int32)
data_importance['accumulated_margin'] = torch.zeros(data_size).type(torch.float32)


def _replay_epochs(last_epoch, metric_fn):
    """Load the per-epoch training-dynamics pickles for epochs
    1..last_epoch and feed each one to metric_fn."""
    for epoch in range(1, last_epoch + 1):
        td_path = f"{args.base_dir}/{args.task_name}/training-dynamics/td-{args.task_name}-epoch-{epoch}.pickle"
        print(td_path)
        # NOTE: pickle.load on trusted local training logs only.
        with open(td_path, 'rb') as f:
            td_data = pickle.load(f)
        metric_fn(td_data['training_dynamics'])


# EL2N uses only the first 10 epochs (stops at epoch 11); the other
# training-dynamics metrics use all 60 epochs.
_replay_epochs(10, lambda td: EL2N(td, data_importance, max_epoch=11))
_replay_epochs(60, lambda td: training_dynamics_metrics(td, data_importance))

data_score_path = args.data_score_path
print(f'Saving data score at {data_score_path}')
with open(data_score_path, 'wb') as handle:
    pickle.dump(data_importance, handle)