Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

weight_averaging_utils.py 6.5 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
  1. import os
  2. import torch
  3. import numpy as np
  4. import pkg_resources
  5. from super_gradients.training import utils as core_utils
  6. from super_gradients.training.utils.utils import move_state_dict_to_device
  7. class ModelWeightAveraging:
  8. """
  9. Utils class for managing the averaging of the best several snapshots into a single model.
  10. A snapshot dictionary file and the average model will be saved / updated at every epoch and evaluated only when
  11. training is completed. The snapshot file will only be deleted upon completing the training.
  12. The snapshot dict will be managed on cpu.
  13. """
  14. def __init__(self, ckpt_dir,
  15. greater_is_better,
  16. source_ckpt_folder_name=None, metric_to_watch='acc',
  17. metric_idx=1, load_checkpoint=False,
  18. number_of_models_to_average=10,
  19. model_checkpoints_location='local'
  20. ):
  21. """
  22. Init the ModelWeightAveraging
  23. :param checkpoint_dir: the directory where the checkpoints are saved
  24. :param metric_to_watch: monitoring loss or acc, will be identical to that which determines best_model
  25. :param metric_idx:
  26. :param load_checkpoint: whether to load pre-existing snapshot dict.
  27. :param number_of_models_to_average: number of models to average
  28. """
  29. if source_ckpt_folder_name is not None:
  30. source_ckpt_file = os.path.join(source_ckpt_folder_name, 'averaging_snapshots.pkl')
  31. source_ckpt_file = pkg_resources.resource_filename('checkpoints', source_ckpt_file)
  32. self.averaging_snapshots_file = os.path.join(ckpt_dir, 'averaging_snapshots.pkl')
  33. self.number_of_models_to_average = number_of_models_to_average
  34. self.metric_to_watch = metric_to_watch
  35. self.metric_idx = metric_idx
  36. self.greater_is_better = greater_is_better
  37. # if continuing training, copy previous snapshot dict if exist
  38. if load_checkpoint and source_ckpt_folder_name is not None and os.path.isfile(source_ckpt_file):
  39. averaging_snapshots_dict = core_utils.load_checkpoint(ckpt_destination_dir=ckpt_dir,
  40. source_ckpt_folder_name=source_ckpt_folder_name,
  41. ckpt_filename="averaging_snapshots.pkl",
  42. load_weights_only=False,
  43. model_checkpoints_location=model_checkpoints_location,
  44. overwrite_local_ckpt=True)
  45. else:
  46. averaging_snapshots_dict = {'snapshot' + str(i): None for i in range(self.number_of_models_to_average)}
  47. # if metric to watch is acc, hold a zero array, if loss hold inf array
  48. if self.greater_is_better:
  49. averaging_snapshots_dict['snapshots_metric'] = -1 * np.inf * np.ones(self.number_of_models_to_average)
  50. else:
  51. averaging_snapshots_dict['snapshots_metric'] = np.inf * np.ones(self.number_of_models_to_average)
  52. torch.save(averaging_snapshots_dict, self.averaging_snapshots_file)
  53. def update_snapshots_dict(self, model, validation_results_tuple):
  54. """
  55. Update the snapshot dict and returns the updated average model for saving
  56. :param model: the latest model
  57. :param validation_results_tuple: performance of the latest model
  58. """
  59. averaging_snapshots_dict = self._get_averaging_snapshots_dict()
  60. # IF CURRENT MODEL IS BETTER, TAKING HIS PLACE IN ACC LIST AND OVERWRITE THE NEW AVERAGE
  61. require_update, update_ind = self._is_better(averaging_snapshots_dict, validation_results_tuple)
  62. if require_update:
  63. # moving state dict to cpu
  64. new_sd = model.state_dict()
  65. new_sd = move_state_dict_to_device(new_sd, 'cpu')
  66. averaging_snapshots_dict['snapshot' + str(update_ind)] = new_sd
  67. averaging_snapshots_dict['snapshots_metric'][update_ind] = validation_results_tuple[self.metric_idx]
  68. return averaging_snapshots_dict
  69. def get_average_model(self, model, validation_results_tuple=None):
  70. """
  71. Returns the averaged model
  72. :param model: will be used to determine arch
  73. :param validation_results_tuple: if provided, will update the average model before returning
  74. :param target_device: if provided, return sd on target device
  75. """
  76. # If validation tuple is provided, update the average model
  77. if validation_results_tuple is not None:
  78. averaging_snapshots_dict = self.update_snapshots_dict(model, validation_results_tuple)
  79. else:
  80. averaging_snapshots_dict = self._get_averaging_snapshots_dict()
  81. torch.save(averaging_snapshots_dict, self.averaging_snapshots_file)
  82. average_model_sd = averaging_snapshots_dict['snapshot0']
  83. for n_model in range(1, self.number_of_models_to_average):
  84. if averaging_snapshots_dict['snapshot' + str(n_model)] is not None:
  85. net_sd = averaging_snapshots_dict['snapshot' + str(n_model)]
  86. # USING MOVING AVERAGE
  87. for key in average_model_sd:
  88. average_model_sd[key] = torch.true_divide(
  89. average_model_sd[key] * n_model + net_sd[key],
  90. (n_model + 1))
  91. return average_model_sd
  92. def cleanup(self):
  93. """
  94. Delete snapshot file when reaching the last epoch
  95. """
  96. os.remove(self.averaging_snapshots_file)
  97. def _is_better(self, averaging_snapshots_dict, validation_results_tuple):
  98. """
  99. Determines if the new model is better according to the specified metrics
  100. :param averaging_snapshots_dict: snapshot dict
  101. :param validation_results_tuple: latest model performance
  102. """
  103. snapshot_metric_array = averaging_snapshots_dict['snapshots_metric']
  104. val = validation_results_tuple[self.metric_idx]
  105. if self.greater_is_better:
  106. update_ind = np.argmin(snapshot_metric_array)
  107. else:
  108. update_ind = np.argmax(snapshot_metric_array)
  109. if (self.greater_is_better and val > snapshot_metric_array[update_ind]) or (
  110. not self.greater_is_better and val < snapshot_metric_array[update_ind]):
  111. return True, update_ind
  112. return False, None
  113. def _get_averaging_snapshots_dict(self):
  114. return torch.load(self.averaging_snapshots_file)
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...