Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

io_utils.py 11 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
  1. import numpy as np
  2. import random
  3. import torch
  4. import os
  5. import glob
  6. import argparse
  7. import models.backbone as backbone
  8. from models.model_resnet import *
  9. model_dict = dict(
  10. Conv4 = backbone.Conv4,
  11. Conv4S = backbone.Conv4S,
  12. Conv6 = backbone.Conv6,
  13. ResNet10 = backbone.ResNet10,
  14. ResNet18 = backbone.ResNet18,
  15. ResNet34 = backbone.ResNet34,
  16. ResNet50 = backbone.ResNet50,
  17. ResNet101 = backbone.ResNet101,
  18. resnet18 = 'resnet18',
  19. resnet18_pytorch = 'resnet18_pytorch',
  20. resnet50_pytorch = 'resnet50_pytorch'
  21. )
  22. def str2bool(v):
  23. if isinstance(v, bool):
  24. return v
  25. if v.lower() in ('yes', 'true', 't', 'y', '1', 'True'):
  26. return True
  27. elif v.lower() in ('no', 'false', 'f', 'n', '0', 'False'):
  28. return False
  29. else:
  30. raise argparse.ArgumentTypeError('Boolean value expected.')
  31. def parse_args(script):
  32. parser = argparse.ArgumentParser(description= 'few-shot script %s' %(script))
  33. parser.add_argument('--dataset' , default='CUB', help='CUB/miniImagenet/cross/omniglot/cross_char')
  34. parser.add_argument('--model' , default='Conv4', help='model: Conv{4|6} / ResNet{10|18|34|50|101}') # 50 and 101 are not used in the paper
  35. parser.add_argument('--method' , default='baseline', help='baseline/baseline++/protonet/matchingnet/relationnet{_softmax}/maml{_approx}') #relationnet_softmax replace L2 norm with softmax to expedite training, maml_approx use first-order approximation in the gradient for efficiency
  36. parser.add_argument('--train_n_way' , default=5, type=int, help='class num to classify for training') #baseline and baseline++ would ignore this parameter
  37. parser.add_argument('--test_n_way' , default=5, type=int, help='class num to classify for testing (validation) ') #baseline and baseline++ only use this parameter in finetuning
  38. parser.add_argument('--n_shot' , default=5, type=int, help='number of labeled data in each class, same as n_support') #baseline and baseline++ only use this parameter in finetuning
  39. parser.add_argument('--train_aug' , type=str2bool, nargs='?', default=True, const=True, help='perform data augmentation or not during training ') #still required for save_features.py and test.py to find the model path correctly
  40. parser.add_argument('--jigsaw' , type=str2bool, nargs='?', default=False, const=True, help='multi-task training')
  41. parser.add_argument('--lbda' , default=0.0, type=float, help='lambda for the jigsaw loss, (1-lambda) for proto loss')
  42. # parser.add_argument('--lbda_proto' , default=1.0, type=float, help='lambda for the protonet loss')
  43. parser.add_argument('--lr' , default=0.001, type=float, help='learning rate')
  44. parser.add_argument('--optimization', default='Adam', type=str, help='Adam or SGD')
  45. parser.add_argument('--loadfile' , default='', type=str, help='load pre-trained model')
  46. parser.add_argument('--finetune' , action='store_true', help='finetuning from jigsaw to protonet')
  47. parser.add_argument('--random' , action='store_true', help='random init net')
  48. parser.add_argument('--n_query' , default=16, type=int, help='number of query, 16 is used in the paper')
  49. parser.add_argument('--image_size' , default=224, type=int, help='224 is used in the paper')
  50. parser.add_argument('--debug' , action='store_true', help='')
  51. parser.add_argument('--json_seed' , default=None, type=str, help='seed for CUB split')
  52. parser.add_argument('--date' , default='', type=str, help='date of the exp')
  53. parser.add_argument('--rotation' , type=str2bool, nargs='?', default=False, const=True, help='multi-task training')
  54. parser.add_argument('--grey' , action='store_true', help='use grey image') # Use for CUB, dogs and flowers only
  55. parser.add_argument('--low_res', type=str2bool, nargs='?', default=False, const=True, help='semi_sup') # Use for cars and aircrafts only
  56. parser.add_argument('--firstk' , default=0, type=int, help='first k images per class for training CUB')
  57. parser.add_argument('--testiter' , default=199, type=int, help='date of the exp')
  58. parser.add_argument('--wd' , default=0.01, type=float, help='weight decay, 0.01 to 0.00001')
  59. parser.add_argument('--bs' , default=16, type=int, help='batch size for baseline, 256 for fgvc?')
  60. parser.add_argument('--iterations' , default=20000, type=int, help='number of iterations')
  61. parser.add_argument('--useVal' , action='store_true', help='use val set as test set')
  62. parser.add_argument('--scheduler' , type=str2bool, nargs='?', default=False, const=True, help='lr scheduler')
  63. # parser.add_argument('--step_size' , default=10000, type=int, help='step for step scheduler')
  64. # parser.add_argument('--gamma' , default=0.2, type=float, help='gamma for step scheduler')
  65. parser.add_argument('--lbda_jigsaw' , default=0.0, type=float, help='lambda for the jigsaw loss, (1-lambda) for proto loss')
  66. parser.add_argument('--lbda_rotation' , default=0.0, type=float, help='lambda for the jigsaw loss, (1-lambda) for proto loss')
  67. parser.add_argument('--pretrain' , type=str2bool, nargs='?', default=False, const=True, help='use imagenet pre-train model')
  68. parser.add_argument('--dataset_unlabel' , default=None, help='CUB/miniImagenet/cross/omniglot/cross_char')
  69. parser.add_argument('--dataset_unlabel_percentage' , default="", help='20,40,60,80')
  70. parser.add_argument('--dataset_percentage' , default="", help='20,40,60,80')
  71. parser.add_argument('--bn_type', default=1, type=int, help="1 for BN+Tracking. 2 for BN + no tracking, 3 for no BN. BN --> BatchNorm")
  72. parser.add_argument('--test_bs' , default=64, type=int, help='batch size for testing w/o batchnorm')
  73. parser.add_argument('--split' , default='novel', help='base/val/novel') #default novel, but you can also test base/val class accuracy if you want
  74. parser.add_argument('--save_iter', default=-1, type=int,help ='saved feature from the model trained in x epoch, use the best model if x is -1')
  75. parser.add_argument('--adaptation' , action='store_true', help='further adaptation in test time or not')
  76. parser.add_argument('--device', type=str, default="0", help='GPU')
  77. parser.add_argument('--seed', type=int, default=42)
  78. parser.add_argument('--amp', type=str2bool, nargs='?', default=False, const=True, help='amp')
  79. if script == 'train':
  80. parser.add_argument('--num_classes' , default=200, type=int, help='total number of classes in softmax, only used in baseline') #make it larger than the maximum label value in base class
  81. parser.add_argument('--save_freq' , default=100, type=int, help='Save frequency')
  82. parser.add_argument('--start_epoch' , default=0, type=int,help ='Starting epoch')
  83. parser.add_argument('--stop_epoch' , default=400, type=int, help ='Stopping epoch') # for meta-learning methods, each epoch contains 100 episodes
  84. parser.add_argument('--resume' , action='store_true', help='continue from previous trained model with largest epoch')
  85. parser.add_argument('--warmup' , action='store_true', help='continue from baseline, neglected if resume is true') #never used in the paper
  86. parser.add_argument('--eval_interval', type=int, default=50, help='eval_interval')
  87. parser.add_argument('--run_name', default=None, help="wandb run name")
  88. parser.add_argument('--run_id', default=None, help="wandb run ID")
  89. parser.add_argument('--semi_sup', type=str2bool, nargs='?', default=False, const=True, help='semi_sup')
  90. parser.add_argument('--sup_ratio', type=float, default=1.0)
  91. parser.add_argument('--only_test', type=str2bool, nargs='?', default=False, const=True)
  92. parser.add_argument('--project', type=str, default="FSL-SSL")
  93. parser.add_argument('--save_model', type=str2bool, nargs='?', default=True, const=True)
  94. parser.add_argument('--demo', type=str2bool, nargs='?', default=False, const=True) # if True, train = 1 epoch, all episodes are 5 episodes (train,val,test)
  95. parser.add_argument('--only_train', type=str2bool, nargs='?', default=False, const=True) # if True, only train
  96. parser.add_argument('--sweep', type=str2bool, nargs='?', default=False, const=True) # if True, train = 1 epoch, all episodes are 5 episodes (train,val,test)
  97. return parser.parse_args()
  98. def get_assigned_file(checkpoint_dir,num):
  99. assign_file = os.path.join(checkpoint_dir, '{:d}.tar'.format(num))
  100. return assign_file
  101. def get_resume_file(checkpoint_dir):
  102. resume_file = os.path.join(checkpoint_dir, 'last_model.tar')
  103. return resume_file
  104. def get_best_file(checkpoint_dir):
  105. best_file = os.path.join(checkpoint_dir, 'best_model.tar')
  106. if os.path.isfile(best_file):
  107. return best_file
  108. else:
  109. return get_resume_file(checkpoint_dir)
  110. class data_prefetcher():
  111. def __init__(self, loader):
  112. self.loader = iter(loader)
  113. self.stream = torch.cuda.Stream()
  114. self.preload()
  115. def preload(self):
  116. try:
  117. self.inputs = list(next(self.loader))
  118. except StopIteration:
  119. self.inputs = None
  120. return
  121. with torch.cuda.stream(self.stream):
  122. for i,tensor in enumerate(self.inputs):
  123. self.inputs[i] = self.inputs[i].cuda(non_blocking=True)
  124. def next(self):
  125. torch.cuda.current_stream().wait_stream(self.stream)
  126. input = self.inputs[0]
  127. target = self.inputs[1]
  128. aux_input = self.inputs[2] if len(self.inputs) >= 4 else None
  129. aux_label = self.inputs[3] if len(self.inputs) >= 4 else None
  130. aux_input_2 = self.inputs[4] if len(self.inputs) >= 5 else None
  131. aux_label_2 = self.inputs[5] if len(self.inputs) >= 6 else None
  132. if input is not None:
  133. input.record_stream(torch.cuda.current_stream())
  134. if target is not None:
  135. target.record_stream(torch.cuda.current_stream())
  136. if aux_input is not None:
  137. aux_input.record_stream(torch.cuda.current_stream())
  138. if aux_label is not None:
  139. aux_label.record_stream(torch.cuda.current_stream())
  140. if aux_input_2 is not None:
  141. aux_input_2.record_stream(torch.cuda.current_stream())
  142. if aux_label_2 is not None:
  143. aux_label_2.record_stream(torch.cuda.current_stream())
  144. self.preload()
  145. return input, target, aux_input, aux_label, aux_input_2, aux_label_2
  146. def set_seed(seed):
  147. random.seed(seed)
  148. np.random.seed(seed)
  149. torch.manual_seed(seed)
  150. if torch.cuda.is_available():
  151. torch.cuda.manual_seed(seed)
  152. torch.cuda.manual_seed_all(seed)
  153. torch.backends.cudnn.deterministic = True
  154. torch.backends.cudnn.benchmark = False
  155. os.environ["PYTHONHASHSEED"] = str(seed)
  156. if __name__ == "__main__":
  157. args = parse_args('train')
  158. print(args)
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...