1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
|
"""Cross-domain few-shot (CD-FSL) evaluation driver.

For every W&B run id listed in ``other/runs.csv``, restore the run's best and
last ProtoNet checkpoints, load the backbone weights, and evaluate 5-way
5-shot accuracy on the ISIC / EuroSAT / Chest CD-FSL benchmarks.  Results are
appended to ``other/cdfsl_results.txt`` and logged back to the W&B run;
failures are appended to ``other/cdfsl_results_logs.txt``.
"""
import warnings
warnings.filterwarnings("ignore")
import numpy as np
import time
import os
import glob
import random
import sys
from utils.io_utils import set_seed, parse_args

# parse_args / set_seed must run before the torch-dependent project imports so
# the seed is fixed before any model/config module touches the RNGs.
params = parse_args('test')
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim
import torch.optim.lr_scheduler as lr_scheduler
set_seed(params.seed)
import config.configs as configs
import models.backbone as backbone
from data.datamgr_2loss import SimpleDataManager, SetDataManager
from methods.protonet_2loss import ProtoNet

from utils.io_utils import model_dict, get_resume_file, get_best_file, get_assigned_file
import json
from models.model_resnet import *
from utils.utils import RunningAverage, Logger, wandb_restore_models
from tqdm import tqdm
import wandb
from data.cdfsl import Chest_few_shot
from data.cdfsl import CropDisease_few_shot
from data.cdfsl import EuroSAT_few_shot
from data.cdfsl import ISIC_few_shot
import csv

timestamp = time.strftime("%Y%m%d-%H%M%S", time.localtime())

# Build the 5-way 5-shot / 16-query / 600-episode evaluation loaders once, at
# both image resolutions, so every checkpoint reuses the same episodes.
datamanagers = {"ISIC": ISIC_few_shot.SetDataManager,
                "EuroSAT": EuroSAT_few_shot.SetDataManager,
                "Chest": Chest_few_shot.SetDataManager}
dataloaders = {}
for dset in datamanagers:
    dataloaders[dset] = {}
    for image_size in (224, 84):
        datamgr = datamanagers[dset](image_size, n_query=16, n_eposide=600,
                                     n_way=5, n_support=5)
        dataloaders[dset][str(image_size)] = datamgr.get_data_loader(aug=False)

# Result/log files are closed even if a run raises (original left them leaked).
with open("other/cdfsl_results.txt", "a") as out_file, \
     open("other/cdfsl_results_logs.txt", "a") as log_file:
    with open('other/runs.csv') as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=',')
        for row in csv_reader:
            run_id = row[0]  # avoid shadowing builtin `id`
            print(run_id)
            # Skip blank rows BEFORE touching wandb (the original checked
            # after wandb.init had already been called with the empty id).
            if not run_id:
                continue
            wandb.init(project="Table-2", entity="meta-learners", id=run_id,
                       resume=True)  # NOTE: Change when project="CDFSL"
            ckpt_dir = wandb.config["checkpoint_dir"]  # avoid shadowing builtin `dir`
            ckpt_dir = ckpt_dir[ckpt_dir.index("results"):]
            if not ckpt_dir:
                wandb.finish()
                continue
            image_size = wandb.config["image_size"]
            model_type = wandb.config["model"]
            run_cfg = wandb.config
            model = ProtoNet(model_dict[model_type], n_way=5, n_support=5,
                             use_bn=(not run_cfg["no_bn"]),
                             pretrain=run_cfg["pretrain"],
                             tracking=run_cfg["tracking"],)
            try:
                for ckpt_file in ("best_model.tar", "last_model.tar"):
                    full_path = os.path.join(ckpt_dir, ckpt_file)
                    pth = wandb.restore(full_path)
                    print("Restored %s" % (pth.name))
                    tmp = torch.load(pth.name)
                    state = tmp['state']
                    # Keep only backbone weights: rename 'feature.trunk.xx'
                    # -> 'trunk.xx' and drop every non-feature key.
                    for key in list(state.keys()):
                        if "feature." in key:
                            state[key.replace("feature.", "")] = state.pop(key)
                        else:
                            state.pop(key)
                    model.feature.load_state_dict(state)
                    model = model.cuda()
                    model.feature = model.feature.cuda()
                    model.feature.eval()
                    model.eval()
                    for dset in datamanagers:
                        print(dset, end=": ")
                        acc_mean, acc_std = model.test_loop(
                            dataloaders[dset][str(image_size)], proto_only=True)
                        # 95% confidence interval over 600 episodes.
                        acc_str_c = '%4.2f%% +- %4.2f%%' % (
                            acc_mean, 1.96 * acc_std / np.sqrt(600))
                        wandb.log({"cdfsl/%s_%s" % (dset, "best" if ckpt_file == "best_model.tar" else "last"): acc_str_c})
                        exp_setting = 'Time: %s, W&B ID: %s, Dataset: %s' % (timestamp, run_id, dset)
                        acc_str = 'Test Acc: %s' % (acc_str_c)
                        out_file.write('%s %s\n' % (exp_setting, acc_str))
                        out_file.flush()
                    print("Removed %s" % (pth.name))
                    os.remove(pth.name)
            except ValueError as ve:
                print(ve)
                log_file.write("ValueError for %s: %s\n" % (run_id, ve))
            except RuntimeError as re:
                print(re)
                log_file.write("RuntimeError for %s: %s\n" % (run_id, re))
            except Exception:
                # Original used a bare `except:`; narrowed so SystemExit /
                # KeyboardInterrupt still propagate.
                print("Unexpected error:", sys.exc_info()[0])
                log_file.write("Unexpected for %s: %s\n" % (run_id, sys.exc_info()[0]))
            finally:
                # Single finish per run (original finished twice on success).
                wandb.finish()
|