save_features.py
import numpy as np
import torch
from torch.autograd import Variable
import os
import glob
import h5py

import config.configs as configs
import models.backbone as backbone
# SimpleDataManager, ProtoNet, and MatchingNet are used below, so their
# imports must stay live; the remaining methods are unused in this script.
from data.datamgr import SimpleDataManager
# from methods.baselinetrain import BaselineTrain
# from methods.baselinefinetune import BaselineFinetune
from methods.protonet import ProtoNet
from methods.matchingnet import MatchingNet
# from methods.relationnet import RelationNet
# from methods.maml import MAML
from utils.io_utils import model_dict, parse_args, get_resume_file, get_best_file, get_assigned_file
from models.model_resnet import *
def save_features(model, data_loader, outfile):
    f = h5py.File(outfile, 'w')
    max_count = len(data_loader) * data_loader.batch_size
    all_labels = f.create_dataset('all_labels', (max_count,), dtype='i')
    all_feats = None
    count = 0
    for i, (x, y) in enumerate(data_loader):
        if i % 10 == 0:
            print('{:d}/{:d}'.format(i, len(data_loader)))
        x = x.cuda()
        x_var = Variable(x)
        feats = model(x_var)
        if all_feats is None:
            # Allocate once the feature dimension is known from the first batch.
            all_feats = f.create_dataset('all_feats', [max_count] + list(feats.size()[1:]), dtype='f')
        all_feats[count:count + feats.size(0)] = feats.data.cpu().numpy()
        all_labels[count:count + feats.size(0)] = y.cpu().numpy()
        count = count + feats.size(0)
    count_var = f.create_dataset('count', (1,), dtype='i')
    count_var[0] = count
    f.close()
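
# Minimal sketch of how the file written by save_features can be read back.
# This helper is hypothetical (not used elsewhere in this script) and assumes
# only the dataset names created above: 'all_feats', 'all_labels', 'count'.
def load_saved_features(featurefile):
    with h5py.File(featurefile, 'r') as f:
        count = f['count'][0]            # rows actually filled; the file is pre-allocated
        feats = f['all_feats'][:count]   # rows beyond count are unused padding
        labels = f['all_labels'][:count]
    return feats, labels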
def save_features_depth(model, data_loader, outfile):
    f = h5py.File(outfile, 'w')
    max_count = len(data_loader) * data_loader.batch_size
    all_labels = f.create_dataset('all_labels', (max_count,), dtype='i')
    all_feats = None
    all_feats_depth = None
    count = 0
    for i, (x, y, depth) in enumerate(data_loader):
        if i % 10 == 0:
            print('{:d}/{:d}'.format(i, len(data_loader)))
        feats = model.set_forward_depth(x, depth, is_feature=False, get_feature=True)
        # Alternative 1: split the image into foreground/background by depth threshold
        # and embed the two parts with separate encoders.
        # foreground = torch.mul(x, (depth > 0.33).float())
        # background = torch.mul(x, (depth < 0.33).float())
        # feats = model.feature(foreground.cuda())
        # feats_depth = model.feature_depth(background.cuda())
        # Alternative 2: embed the RGB image and the depth map separately.
        # x = x.cuda()
        # depth = depth.cuda()
        # x_var = Variable(x)
        # feats = model.feature(x)
        # feats_depth = model.feature_depth(depth)
        if all_feats is None:
            all_feats = f.create_dataset('all_feats', [max_count] + list(feats.size()[1:]), dtype='f')
            # all_feats_depth = f.create_dataset('all_feats_depth', [max_count] + list(feats_depth.size()[1:]), dtype='f')
        all_feats[count:count + feats.size(0)] = feats.data.cpu().numpy()
        # all_feats_depth[count:count + feats_depth.size(0)] = feats_depth.data.cpu().numpy()
        all_labels[count:count + feats.size(0)] = y.cpu().numpy()
        count = count + feats.size(0)
    count_var = f.create_dataset('count', (1,), dtype='i')
    count_var[0] = count
    f.close()
if __name__ == '__main__':
    params = parse_args('save_features')
    isAircraft = (params.dataset == 'aircrafts')
    assert params.method != 'maml' and params.method != 'maml_approx', 'MAML does not support save_features'

    if 'Conv' in params.model:
        if params.dataset in ['omniglot', 'cross_char']:
            image_size = 28
        else:
            image_size = 84
    else:
        # image_size = 224  # original setting
        # image_size = 256  # my setting
        image_size = params.image_size

    if params.dataset in ['omniglot', 'cross_char']:
        assert params.model == 'Conv4' and not params.train_aug, 'omniglot only supports Conv4 without augmentation'
        params.model = 'Conv4S'

    split = params.split
    if params.dataset == 'cross':
        if split == 'base':
            loadfile = configs.data_dir['miniImagenet'] + 'all.json'
        else:
            loadfile = configs.data_dir['CUB'] + split + '.json'
    elif params.dataset == 'cross_char':
        if split == 'base':
            loadfile = configs.data_dir['omniglot'] + 'noLatin.json'
        else:
            loadfile = configs.data_dir['emnist'] + split + '.json'
    else:
        if params.json_seed is not None:
            loadfile = configs.data_dir[params.dataset] + split + params.json_seed + '.json'
        else:
            if '_' in params.dataset:
                loadfile = configs.data_dir[params.dataset.split('_')[0]] + split + '.json'
            else:
                loadfile = configs.data_dir[params.dataset] + split + '.json'

    if params.json_seed is not None:
        checkpoint_dir = '%s/checkpoints/%s_%s/%s_%s_%s' % (configs.save_dir, params.dataset, params.json_seed, params.date, params.model, params.method)
    else:
        checkpoint_dir = '%s/checkpoints/%s/%s_%s_%s' % (configs.save_dir, params.dataset, params.date, params.model, params.method)
    if params.train_aug:
        checkpoint_dir += '_aug'
    if params.method not in ['baseline', 'baseline++']:
        checkpoint_dir += '_%dway_%dshot_%dquery' % (params.train_n_way, params.n_shot, params.n_query)
    checkpoint_dir += '_%d' % image_size

    ## Use another dataset (dataloader) for unlabeled data
    if params.dataset_unlabel is not None:
        checkpoint_dir += params.dataset_unlabel
        checkpoint_dir += str(params.bs)

    ## Use grey images
    if params.grey:
        checkpoint_dir += '_grey'

    ## Add jigsaw
    if params.jigsaw:
        checkpoint_dir += '_jigsawonly_alldata_lbda%.2f' % (params.lbda)
        checkpoint_dir += params.optimization
    ## Add rotation
    if params.rotation:
        checkpoint_dir += '_rotation_lbda%.2f' % (params.lbda)
        checkpoint_dir += params.optimization

    checkpoint_dir += '_lr%.4f' % (params.lr)
    if params.finetune:
        checkpoint_dir += '_finetune'
    if params.random:
        checkpoint_dir = 'checkpoints/CUB/random'
    print('checkpoint_dir:', checkpoint_dir)

    if params.loadfile != '':
        modelfile = params.loadfile
        checkpoint_dir = params.loadfile
    else:
        if params.save_iter != -1:
            modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
        elif params.method in ['baseline', 'baseline++']:
            modelfile = get_resume_file(checkpoint_dir)
        else:
            modelfile = get_best_file(checkpoint_dir)

    # Features are written next to the checkpoints, with 'checkpoints' swapped for 'features'.
    if params.save_iter != -1:
        outfile = os.path.join(checkpoint_dir.replace("checkpoints", "features"), split + "_" + str(params.save_iter) + ".hdf5")
    else:
        outfile = os.path.join(checkpoint_dir.replace("checkpoints", "features"), split + ".hdf5")

    datamgr = SimpleDataManager(image_size, batch_size=params.test_bs, isAircraft=isAircraft, grey=params.grey)
    data_loader = datamgr.get_data_loader(loadfile, aug=False)

    if params.method in ['relationnet', 'relationnet_softmax']:
        if params.model == 'Conv4':
            model = backbone.Conv4NP()
        elif params.model == 'Conv6':
            model = backbone.Conv6NP()
        elif params.model == 'Conv4S':
            model = backbone.Conv4SNP()
        else:
            model = model_dict[params.model](flatten=False)
    elif params.method in ['maml', 'maml_approx']:
        raise ValueError('MAML does not support saving features')
    else:
        train_few_shot_params = dict(n_way=params.train_n_way, n_support=params.n_shot,
                                     jigsaw=params.jigsaw, lbda=params.lbda,
                                     rotation=params.rotation, tracking=params.tracking)
        if params.method == 'protonet':
            print("USE BN:", not params.no_bn)
            model = ProtoNet(model_dict[params.model], **train_few_shot_params, use_bn=(not params.no_bn))
        elif params.method == 'matchingnet':
            model = MatchingNet(model_dict[params.model], **train_few_shot_params)
        else:  # baseline and baseline++
            if isinstance(model_dict[params.model], str):
                if model_dict[params.model] == 'resnet18':
                    model = ResidualNet('ImageNet', 18, 1000, None)
            else:
                model = model_dict[params.model]()

    model = model.cuda()
    if params.method != 'baseline':
        model.feature = model.feature.cuda()

    tmp = torch.load(modelfile)
    state = tmp['state']
    state_keys = list(state.keys())
    for i, key in enumerate(state_keys):
        if "feature." in key:
            # The wrapper model stores the backbone under 'feature'; strip that prefix
            # ('feature.trunk.xx' -> 'trunk.xx') so the keys match the bare backbone.
            newkey = key.replace("feature.", "")
            state[newkey] = state.pop(key)
        else:
            state.pop(key)

    # if params.method != 'baseline':
    model.feature.load_state_dict(state)
    model.feature.eval()
    # else:
    #     model.load_state_dict(state)
    model.eval()

    dirname = os.path.dirname(outfile)
    if not os.path.isdir(dirname):
        os.makedirs(dirname)
    # outfile += '_finetune'
    print('outfile is', outfile)
    save_features(model, data_loader, outfile)
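
# Usage sketch: the exact flag names are defined by parse_args('save_features')
# in utils/io_utils.py, which is not shown here, so the flags below are
# assumptions inferred from how params is used above:
#
#   python save_features.py --dataset miniImagenet --model Conv4 --method baseline++ --train_aug
#
# The extracted features land in an .hdf5 file under a 'features/' tree that
# mirrors checkpoint_dir (via checkpoint_dir.replace("checkpoints", "features")).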