ema.py
import math
import warnings
from copy import deepcopy
from typing import Union

import torch
from torch import nn

from super_gradients.training.models import SgModule


def copy_attr(a: nn.Module, b: nn.Module, include: Union[list, tuple] = (), exclude: Union[list, tuple] = ()):
    # Copy attributes from b to a, with options to only include [...] and to exclude [...]
    for k, v in b.__dict__.items():
        if (len(include) and k not in include) or k.startswith('_') or k in exclude:
            continue
        else:
            setattr(a, k, v)


class ModelEMA:
    """Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
    Keeps a moving average of everything in the model state_dict (parameters and buffers).
    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    A smoothed version of the weights is necessary for some training schemes to perform well.
    This class is sensitive to where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.
    """

    def __init__(self, model, decay: float = 0.9999, beta: float = 15, exp_activation: bool = True):
        """
        Init the EMA.
        :param model: Union[SgModule, nn.Module], the training model to construct the EMA model from.
            IMPORTANT: WHEN EMA SHOULD ONLY BE APPLIED TO A SUBSET OF ATTRIBUTES, WRAP THE NN.MODULE
            AS AN SgModule AND OVERRIDE get_include_attributes() AND get_exclude_attributes() AS DESIRED (SEE
            THE YoLoV5Base IMPLEMENTATION IN super_gradients.trainer.models.yolov5.py AS AN EXAMPLE).
        :param decay: the maximum decay value. As the training process advances, the decay climbs towards this value
            until EMA_t+1 = EMA_t * decay + TRAINING_MODEL * (1 - decay).
        :param beta: the exponent coefficient. The higher the beta, the sooner in the training the decay saturates to
            its final value. beta=15 saturates at ~40% of the training process.
        :param exp_activation: whether to ramp the decay exponentially over training (True) or keep it constant (False).
        """
        # Create EMA (the model is expected to be wrapped, e.g. by DataParallel/DDP, hence the .module accesses below)
        self.ema = deepcopy(model)
        self.ema.eval()
        if exp_activation:
            self.decay_function = lambda x: decay * (1 - math.exp(-x * beta))  # decay exponential ramp (to help early epochs)
        else:
            self.decay_function = lambda x: decay  # always return the same decay factor

        """
        We hold a list of model attributes (not weights and biases) which we would like to include in each
        attribute update or exclude from each update. An SgModule declares these attributes using the
        get_include_attributes and get_exclude_attributes functions. For an nn.Module which is not an SgModule,
        all non-private (not starting with '_') attributes will be updated (and only them).
        """
        if isinstance(model.module, SgModule):
            self.include_attributes = model.module.get_include_attributes()
            self.exclude_attributes = model.module.get_exclude_attributes()
        else:
            warnings.warn("Warning: EMA should be used with an SgModule instance. All attributes of the model will be "
                          "included in EMA")
            self.include_attributes = []
            self.exclude_attributes = []

        for p in self.ema.module.parameters():
            p.requires_grad_(False)

    def update(self, model, training_percent: float):
        """
        Update the state of the EMA model.
        :param model: the current training model
        :param training_percent: the fraction of the training process in [0, 1], i.e. 0.4 means 40% of the training has passed
        """
        # Update EMA parameters
        with torch.no_grad():
            decay = self.decay_function(training_percent)
            for ema_v, model_v in zip(self.ema.module.state_dict().values(), model.state_dict().values()):
                if ema_v.dtype.is_floating_point:
                    ema_v.copy_(ema_v * decay + (1. - decay) * model_v.detach())

    def update_attr(self, model):
        """
        This function updates model attributes (not weights and biases) from the original model to the EMA model.
        Attributes of the original model, such as anchors and grids (of detection models), may be crucial to
        the model's operation and need to be updated.
        If the include_attributes and exclude_attributes lists were not defined, all non-private (not starting with '_')
        attributes will be updated (and only them).
        :param model: the source model
        """
        copy_attr(self.ema.module, model.module, self.include_attributes, self.exclude_attributes)
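
Usage note: ModelEMA implements the update EMA_t+1 = decay * EMA_t + (1 - decay) * model_t, with the decay itself ramped over the training fraction x as decay(x) = decay_max * (1 - exp(-beta * x)). Below is a minimal, hypothetical sketch of driving the class from a training loop; it is not part of ema.py. It assumes the training model is wrapped (here in nn.DataParallel) so that the .module accesses inside ModelEMA resolve, and it substitutes a toy linear model, synthetic batches and placeholder hyper-parameters for a real setup.

# Minimal usage sketch (assumptions: a wrapped model, toy data, placeholder hyper-parameters)
model = nn.DataParallel(nn.Linear(10, 2))  # toy model; ModelEMA expects a wrapped module
ema = ModelEMA(model, decay=0.9999, beta=15)  # warns: not an SgModule, all attributes included
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

num_epochs, steps_per_epoch = 3, 100
total_steps = num_epochs * steps_per_epoch
step = 0
for epoch in range(num_epochs):
    for _ in range(steps_per_epoch):
        inputs, targets = torch.randn(8, 10), torch.randn(8, 2)  # synthetic batch
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()
        step += 1
        # decay ramps as decay * (1 - exp(-beta * training_percent)) toward 0.9999
        ema.update(model, training_percent=step / total_steps)
    ema.update_attr(model)  # sync non-tensor attributes (e.g. detection anchors/grids)

eval_model = ema.ema  # the smoothed model, already in eval mode

With beta=15 the ramp 1 - exp(-15 * x) reaches ~0.95 at x = 0.2 and ~0.998 at x = 0.4, consistent with the docstring's note that the decay effectively saturates around 40% of training.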