#647 Feature/sg 573 Integrate new EMA decay schedules

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-573-Integrate-EMA
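At a glance, this PR replaces the boolean `exp_activation` flag in `ema_params` with an explicit `decay_type` string and adds a small module of pluggable decay schedules. A minimal before/after sketch of the recipe change applied in the config diffs below (shown as Python dicts purely for illustration; the recipes themselves are YAML):

# Illustrative only; keys taken from the recipe diffs below.
old_ema_params = {"decay": 0.9999, "beta": 15, "exp_activation": True}  # deprecated
new_ema_params = {"decay": 0.9999, "beta": 15, "decay_type": "exp"}     # "exp" | "constant" | "threshold"
# exp_activation: True  ->  decay_type: exp
# exp_activation: False ->  decay_type: constant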
@@ -16,7 +16,7 @@ ema: True
 ema_params:
   decay: 0.9999
   beta: 15
-  exp_activation: True
+  decay_type: exp

 train_metrics_list:
   - PixelAccuracy:
@@ -30,8 +30,8 @@ criterion_params: {} # when `loss` is one of SuperGradient's built in options, i
 ema: False # whether to use Model Exponential Moving Average
 ema_params: # parameters for the ema model.
   decay: 0.9999
+  decay_type: exp
   beta: 15
-  exp_activation: True


 train_metrics_list: [] # Metrics to log during training. For more information on torchmetrics see https://torchmetrics.rtfd.io/en/latest/.
@@ -17,8 +17,8 @@ optimizer_params:

 ema: True
 ema_params:
-  exp_activation: False
   decay: 0.9999
+  decay_type: constant

 loss: cross_entropy
 criterion_params:
@@ -42,4 +42,3 @@ valid_metrics_list:                               # metrics for evaluation
   - Top5

 _convert_: all
-
@@ -17,7 +17,7 @@ optimizer_params:

 ema: True
 ema_params:
-  exp_activation: False
+  decay_type: constant
   decay: 0.9999

 loss: cross_entropy
 criterion_params:
@@ -1,21 +1,14 @@
+from typing import Union, Dict, Mapping, Any
+
 import hydra
 import torch.nn
 from omegaconf import DictConfig, OmegaConf
 from torch.utils.data import DataLoader

-from super_gradients.training.utils.distributed_training_utils import setup_device
 from super_gradients.common import MultiGPUMode
-from super_gradients.training.dataloaders import dataloaders
-from super_gradients.training.models import SgModule
-from super_gradients.training.models.all_architectures import KD_ARCHITECTURES
-from super_gradients.training.models.kd_modules.kd_module import KDModule
-from super_gradients.training.sg_trainer import Trainer
-from typing import Union, Dict
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.training import utils as core_utils, models
-from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
-from super_gradients.training.utils import get_param, HpmStruct
-from super_gradients.training.utils.checkpoint_utils import read_ckpt_state_dict, load_checkpoint_to_model
+from super_gradients.training.dataloaders import dataloaders
 from super_gradients.training.exceptions.kd_trainer_exceptions import (
     ArchitectureKwargsException,
     UnsupportedKDArchitectureException,
@@ -24,7 +17,15 @@ from super_gradients.training.exceptions.kd_trainer_exceptions import (
     TeacherKnowledgeException,
     UndefinedNumClassesException,
 )
+from super_gradients.training.models import SgModule
+from super_gradients.training.models.all_architectures import KD_ARCHITECTURES
+from super_gradients.training.models.kd_modules.kd_module import KDModule
+from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
+from super_gradients.training.sg_trainer import Trainer
+from super_gradients.training.utils import get_param, HpmStruct
 from super_gradients.training.utils.callbacks import KDModelMetricsUpdateCallback
+from super_gradients.training.utils.checkpoint_utils import read_ckpt_state_dict, load_checkpoint_to_model
+from super_gradients.training.utils.distributed_training_utils import setup_device
 from super_gradients.training.utils.ema import KDModelEMA

 logger = get_logger(__name__)
@@ -255,17 +256,15 @@ class KDTrainer(Trainer):
         )
         return hyper_param_config

-    def _instantiate_ema_model(self, decay: float = 0.9999, beta: float = 15, exp_activation: bool = True) -> KDModelEMA:
-        """Instantiate KD ema model for KDModule.
-
-        If the model is of class KDModule, the instance will be adapted to work on knowledge distillation.
-        :param decay:           the maximum decay value. as the training process advances, the decay will climb towards
-                                this value until the EMA_t+1 = EMA_t * decay + TRAINING_MODEL * (1- decay)
-        :param beta:            the exponent coefficient. The higher the beta, the sooner in the training the decay will
-                                saturate to its final value. beta=15 is ~40% of the training process.
-        :param exp_activation:
+    def _instantiate_ema_model(self, ema_params: Mapping[str, Any]) -> KDModelEMA:
+        """Instantiate a KD EMA model for a KDModule.
+
+        If the model is of class KDModule, the instance will be adapted to work on knowledge distillation.
+        :param ema_params: Mapping with the EMA configuration: `decay`, `decay_type` and any schedule-specific
+                           kwargs (e.g. `beta`). See EMA_DECAY_FUNCTIONS for more details.
         """
-        return KDModelEMA(self.net, decay, beta, exp_activation)
+        logger.info(f"Using EMA with params {ema_params}")
+        return KDModelEMA.from_params(self.net, **ema_params)

     def _save_best_checkpoint(self, epoch, state):
         """
@@ -1,41 +1,39 @@
 import inspect
 import os
 from copy import deepcopy
-from typing import Union, Tuple, Mapping, Dict
 from pathlib import Path
+from typing import Union, Tuple, Mapping, Dict, Any

+import hydra
 import numpy as np
 import torch
-import hydra
 from omegaconf import DictConfig
+from omegaconf import OmegaConf
+from piptools.scripts.sync import _get_installed_distributions
 from torch import nn
-from torch.utils.data import DataLoader, SequentialSampler
 from torch.cuda.amp import GradScaler, autocast
+from torch.utils.data import DataLoader, SequentialSampler
+from torch.utils.data.distributed import DistributedSampler
 from torchmetrics import MetricCollection
 from tqdm import tqdm
-from piptools.scripts.sync import _get_installed_distributions
-
-from torch.utils.data.distributed import DistributedSampler

-from super_gradients.training.datasets.samplers import InfiniteSampler, RepeatAugSampler
-
-from super_gradients.common.factories.callbacks_factory import CallbacksFactory
+from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.data_types.enum import MultiGPUMode, StrictLoad, EvaluationType
-from super_gradients.training.models.all_architectures import ARCHITECTURES
 from super_gradients.common.decorators.factory_decorator import resolve_param
-from super_gradients.common.abstractions.abstract_logger import get_logger
+from super_gradients.common.environment.device_utils import device_config
+from super_gradients.common.factories.callbacks_factory import CallbacksFactory
 from super_gradients.common.factories.list_factory import ListFactory
 from super_gradients.common.factories.losses_factory import LossesFactory
 from super_gradients.common.factories.metrics_factory import MetricsFactory
+from super_gradients.common.factories.pre_launch_callbacks_factory import PreLaunchCallbacksFactory
 from super_gradients.common.sg_loggers import SG_LOGGERS
 from super_gradients.common.sg_loggers.abstract_sg_logger import AbstractSGLogger
 from super_gradients.common.sg_loggers.base_sg_logger import BaseSGLogger
 from super_gradients.training import utils as core_utils, models, dataloaders
-from super_gradients.training.models import SgModule
-from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
-from super_gradients.training.utils import sg_trainer_utils, get_param
-from super_gradients.training.utils.sg_trainer_utils import MonitoredValue, log_main_training_params
+from super_gradients.training.datasets.datasets_utils import DatasetStatisticsTensorboardLogger
+from super_gradients.training.datasets.samplers import InfiniteSampler, RepeatAugSampler
 from super_gradients.training.exceptions.sg_trainer_exceptions import UnsupportedOptimizerFormat
+from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.metrics.metric_utils import (
     get_metrics_titles,
     get_metrics_results_tuple,
@@ -43,7 +41,30 @@ from super_gradients.training.metrics.metric_utils import (
     get_metrics_dict,
     get_train_loop_description_dict,
 )
+from super_gradients.training.models import SgModule
+from super_gradients.training.models.all_architectures import ARCHITECTURES
 from super_gradients.training.params import TrainingParams
+from super_gradients.training.pretrained_models import PRETRAINED_NUM_CLASSES
+from super_gradients.training.utils import HpmStruct
+from super_gradients.training.utils import random_seed
+from super_gradients.training.utils import sg_trainer_utils, get_param
+from super_gradients.training.utils.callbacks import (
+    CallbackHandler,
+    Phase,
+    LR_SCHEDULERS_CLS_DICT,
+    PhaseContext,
+    MetricsUpdateCallback,
+    LR_WARMUP_CLS_DICT,
+    ContextSgMethods,
+    LRCallbackBase,
+)
+from super_gradients.training.utils.checkpoint_utils import (
+    get_ckpt_local_path,
+    read_ckpt_state_dict,
+    load_checkpoint_to_model,
+    load_pretrained_weights,
+    get_checkpoints_dir_path,
+)
 from super_gradients.training.utils.distributed_training_utils import (
     MultiGPUModeAutocastWrapper,
     reduce_results_tuple_for_ddp,
@@ -59,34 +80,11 @@ from super_gradients.training.utils.distributed_training_utils import (
     DDPNotSetupException,
 )
 from super_gradients.training.utils.ema import ModelEMA
+from super_gradients.training.utils.hydra_utils import load_experiment_cfg, add_params_to_cfg
 from super_gradients.training.utils.optimizer_utils import build_optimizer
+from super_gradients.training.utils.sg_trainer_utils import MonitoredValue, log_main_training_params
 from super_gradients.training.utils.utils import fuzzy_idx_in_list
 from super_gradients.training.utils.weight_averaging_utils import ModelWeightAveraging
-from super_gradients.training.metrics import Accuracy, Top5
-from super_gradients.training.utils import random_seed
-from super_gradients.training.utils.checkpoint_utils import (
-    get_ckpt_local_path,
-    read_ckpt_state_dict,
-    load_checkpoint_to_model,
-    load_pretrained_weights,
-    get_checkpoints_dir_path,
-)
-from super_gradients.training.datasets.datasets_utils import DatasetStatisticsTensorboardLogger
-from super_gradients.training.utils.callbacks import (
-    CallbackHandler,
-    Phase,
-    LR_SCHEDULERS_CLS_DICT,
-    PhaseContext,
-    MetricsUpdateCallback,
-    LR_WARMUP_CLS_DICT,
-    ContextSgMethods,
-    LRCallbackBase,
-)
-from super_gradients.common.environment.device_utils import device_config
-from super_gradients.training.utils import HpmStruct
-from super_gradients.training.utils.hydra_utils import load_experiment_cfg, add_params_to_cfg
-from omegaconf import OmegaConf
-from super_gradients.common.factories.pre_launch_callbacks_factory import PreLaunchCallbacksFactory

 logger = get_logger(__name__)

@@ -547,9 +545,11 @@ class Trainer:
         self.phase_callback_handler.on_train_batch_backward_end(context)

         # ACCUMULATE GRADIENT FOR X BATCHES BEFORE OPTIMIZING
-        integrated_batches_num = batch_idx + len(self.train_loader) * epoch + 1
+        local_step = batch_idx + 1
+        global_step = local_step + len(self.train_loader) * epoch
+        total_steps = len(self.train_loader) * self.max_epochs

-        if integrated_batches_num % self.batch_accumulate == 0:
+        if global_step % self.batch_accumulate == 0:
             self.phase_callback_handler.on_train_batch_gradient_step_start(context)

             # APPLY GRADIENT CLIPPING IF REQUIRED
@@ -563,7 +563,7 @@

             self.optimizer.zero_grad()
             if self.ema:
-                self.ema_model.update(self.net, integrated_batches_num / (len(self.train_loader) * self.max_epochs))
+                self.ema_model.update(self.net, step=global_step, total_steps=total_steps)

             # RUN PHASE CALLBACKS
             self.phase_callback_handler.on_train_batch_gradient_step_end(context)
@@ -1083,9 +1083,7 @@ class Trainer:
         num_batches = len(self.train_loader)

         if self.ema:
-            ema_params = self.training_params.ema_params
-            logger.info(f"Using EMA with params {ema_params}")
-            self.ema_model = self._instantiate_ema_model(**ema_params)
+            self.ema_model = self._instantiate_ema_model(self.training_params.ema_params)
             self.ema_model.updates = self.start_epoch * num_batches // self.batch_accumulate
             if self.load_checkpoint:
                 if "ema_net" in self.checkpoint.keys():
@@ -1903,14 +1901,15 @@ class Trainer:

         return net

-    def _instantiate_ema_model(self, decay: float = 0.9999, beta: float = 15, exp_activation: bool = True) -> ModelEMA:
+    def _instantiate_ema_model(self, ema_params: Mapping[str, Any]) -> ModelEMA:
         """Instantiate ema model for standard SgModule.
-        :param decay: the maximum decay value. as the training process advances, the decay will climb towards this value
-                      until the EMA_t+1 = EMA_t * decay + TRAINING_MODEL * (1- decay)
-        :param beta: the exponent coefficient. The higher the beta, the sooner in the training the decay will saturate to
-                     its final value. beta=15 is ~40% of the training process.
+        :param ema_params: Mapping with the EMA configuration: `decay`, `decay_type` and any schedule-specific
+                           kwargs (e.g. `beta`). As training advances, the decay climbs towards `decay` according
+                           to the `decay_type` schedule. See EMA_DECAY_FUNCTIONS for the available schedules.
+                           The mapping is forwarded to ModelEMA.from_params.
         """
-        return ModelEMA(self.net, decay, beta, exp_activation)
+        logger.info(f"Using EMA with params {ema_params}")
+        return ModelEMA.from_params(self.net, **ema_params)

     @property
     def get_net(self):
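For context, a rough sketch (not from the PR; every number below is assumed) of the step bookkeeping introduced in the training-loop hunks above and how it replaces the old training-percent argument:

# Illustrative arithmetic only; names mirror the diff, values are made up.
len_train_loader = 500                                   # batches per epoch (assumed)
max_epochs = 100                                         # assumed
batch_idx, epoch = 49, 10                                # assumed position in training

local_step = batch_idx + 1                               # 50
global_step = local_step + len_train_loader * epoch      # 5050
total_steps = len_train_loader * max_epochs              # 50000

# Old call: ema_model.update(net, 5050 / 50000)  ->  training_percent of ~0.101
# New call: ema_model.update(net, step=5050, total_steps=50000); the decay schedule
# computes the ratio itself (see EMA_DECAY_FUNCTIONS below).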
@@ -1,4 +1,3 @@
-import math
 import warnings
 from copy import deepcopy
 from typing import Union
@@ -6,9 +5,14 @@ from typing import Union
 import torch
 from torch import nn

+from super_gradients.common.abstractions.abstract_logger import get_logger
+from super_gradients.common.exceptions.factory_exceptions import UnknownTypeException
 from super_gradients.training import utils as core_utils
 from super_gradients.training.models import SgModule
 from super_gradients.training.models.kd_modules.kd_module import KDModule
+from super_gradients.training.utils.ema_decay_schedules import IDecayFunction, EMA_DECAY_FUNCTIONS
+
+logger = get_logger(__name__)


 def copy_attr(a: nn.Module, b: nn.Module, include: Union[list, tuple] = (), exclude: Union[list, tuple] = ()):
@@ -30,7 +34,7 @@ class ModelEMA:
     GPU assignment and distributed training wrappers.
     """

-    def __init__(self, model, decay: float = 0.9999, beta: float = 15, exp_activation: bool = True):
+    def __init__(self, model, decay: float, decay_function: IDecayFunction):
         """
         Init the EMA
         :param model: Union[SgModule, nn.Module], the training model to construct the EMA model by
@@ -44,10 +48,8 @@ class ModelEMA:
         # Create EMA
         self.ema = deepcopy(model)
         self.ema.eval()
-        if exp_activation:
-            self.decay_function = lambda x: decay * (1 - math.exp(-x * beta))  # decay exponential ramp (to help early epochs)
-        else:
-            self.decay_function = lambda x: decay  # always return the same decay factor
+        self.decay = decay
+        self.decay_function = decay_function

         """"
         we hold a list of model attributes (not wights and biases) which we would like to include in each
@@ -65,15 +67,72 @@ class ModelEMA:
         for p in self.ema.module.parameters():
             p.requires_grad_(False)

-    def update(self, model, training_percent: float):
+    @classmethod
+    def from_params(cls, model: nn.Module, decay_type: str = None, decay: float = None, **kwargs):
+        if decay is None:
+            logger.warning(
+                "Parameter `decay` is not specified for EMA params. Please specify `decay` parameter explicitly in your config:\n"
+                "ema: True\n"
+                "ema_params: \n"
+                "  decay: 0.9999\n"
+                "  decay_type: exp\n"
+                "  beta: 15\n"
+                "Will default to decay: 0.9999\n"
+                "In the next major release of SG this warning will become an error."
+            )
+            decay = 0.9999
+
+        if "exp_activation" in kwargs:
+            logger.warning(
+                "Parameter `exp_activation` is deprecated for EMA model. Please update your config to use decay_type: str (constant|exp|threshold) instead:\n"
+                "ema: True\n"
+                "ema_params: \n"
+                "  decay: 0.9999\n"
+                "  decay_type: exp # Equivalent to exp_activation: True\n"
+                "  beta: 15\n"
+                "\n"
+                "ema: True\n"
+                "ema_params: \n"
+                "  decay: 0.9999\n"
+                "  decay_type: constant # Equivalent to exp_activation: False\n"
+                "\n"
+                "In the next major release of SG this warning will become an error."
+            )
+            decay_type = "exp" if bool(kwargs.pop("exp_activation")) else "constant"
+
+        if decay_type is None:
+            logger.warning(
+                "Parameter decay_type is not specified for EMA model. Please specify decay_type parameter explicitly in your config:\n"
+                "ema: True\n"
+                "ema_params: \n"
+                "  decay: 0.9999\n"
+                "  decay_type: constant|exp|threshold\n"
+                "Will default to `exp` decay with beta = 15\n"
+                "In the next major release of SG this warning will become an error."
+            )
+            decay_type = "exp"
+            if "beta" not in kwargs:
+                kwargs["beta"] = 15
+
+        try:
+            decay_cls = EMA_DECAY_FUNCTIONS[decay_type]
+        except KeyError:
+            raise UnknownTypeException(decay_type, list(EMA_DECAY_FUNCTIONS.keys()))
+
+        decay_function = decay_cls(**kwargs)
+        return cls(model, decay, decay_function)
+
+    def update(self, model, step: int, total_steps: int):
         """
         Update the state of the EMA model.
-        :param model: current training model
-        :param training_percent: the percentage of the training process [0,1]. i.e 0.4 means 40% of the training have passed
+
+        :param model: Current training model
+        :param step: Current training step
+        :param total_steps: Total training steps
         """
         # Update EMA parameters
         with torch.no_grad():
-            decay = self.decay_function(training_percent)
+            decay = self.decay_function(self.decay, step, total_steps)

             for ema_v, model_v in zip(self.ema.module.state_dict().values(), model.state_dict().values()):
                 if ema_v.dtype.is_floating_point:
@@ -101,7 +160,7 @@ class KDModelEMA(ModelEMA):
     GPU assignment and distributed training wrappers.
     """

-    def __init__(self, kd_model: KDModule, decay: float = 0.9999, beta: float = 15, exp_activation: bool = True):
+    def __init__(self, kd_model: KDModule, decay: float, decay_function: IDecayFunction):
         """
         Init the EMA
         :param kd_model: KDModule, the training Knowledge distillation model to construct the EMA model by
@@ -113,7 +172,7 @@ class KDModelEMA(ModelEMA):
                      its final value. beta=15 is ~40% of the training process.
         """
         # Only work on the student (we don't want to update and to have a duplicate of the teacher)
-        super().__init__(model=core_utils.WrappedModel(kd_model.module.student), decay=decay, beta=beta, exp_activation=exp_activation)
+        super().__init__(model=core_utils.WrappedModel(kd_model.module.student), decay=decay, decay_function=decay_function)

         # Overwrite current ema attribute with combination of the student model EMA (current self.ema)
         # with already the instantiated teacher, to have the final KD EMA
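A minimal usage sketch of the new `from_params` factory (not part of the PR; the toy model and the WrappedModel wrapping are assumptions, chosen because ModelEMA iterates `self.ema.module.parameters()`, as the KD path above shows):

import torch.nn as nn

from super_gradients.training import utils as core_utils
from super_gradients.training.utils.ema import ModelEMA

# WrappedModel provides the `.module` attribute that ModelEMA expects.
net = core_utils.WrappedModel(nn.Linear(8, 8))

# New-style params: decay_type selects an entry of EMA_DECAY_FUNCTIONS.
ema = ModelEMA.from_params(net, decay=0.9999, decay_type="exp", beta=15)
ema.update(net, step=100, total_steps=10_000)

# Legacy configs keep working: exp_activation=True/False is mapped to
# decay_type="exp"/"constant" with a deprecation warning.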
New module (imported above as super_gradients.training.utils.ema_decay_schedules):

import math
from abc import abstractmethod

__all__ = ["IDecayFunction", "ConstantDecay", "ThresholdDecay", "ExpDecay", "EMA_DECAY_FUNCTIONS"]


class IDecayFunction:
    """
    Interface for EMA decay schedule. The decay schedule is a function of the maximum decay value and training progress.
    Usually it gradually increases the EMA decay towards the maximum value; the exact ramp-up schedule is defined by the
    concrete implementation.
    """

    @abstractmethod
    def __call__(self, decay: float, step: int, total_steps: int) -> float:
        """
        :param decay:       The maximum decay value.
        :param step:        Current training step. The unit-range training percentage can be obtained by `step / total_steps`.
        :param total_steps: Total number of training steps.
        :return:            Computed decay value for a given step.
        """
        pass


class ConstantDecay(IDecayFunction):
    """
    Constant decay schedule.
    """

    def __init__(self, **kwargs):
        pass

    def __call__(self, decay: float, step: int, total_steps: int) -> float:
        return decay


class ThresholdDecay(IDecayFunction):
    """
    Gradually increase EMA decay from 0.1 to the maximum value using the following formula: min(decay, (1 + step) / (10 + step))
    """

    def __init__(self, **kwargs):
        pass

    def __call__(self, decay: float, step: int, total_steps: int) -> float:
        return min(decay, (1 + step) / (10 + step))


class ExpDecay(IDecayFunction):
    """
    Gradually increase EMA decay from 0 to the maximum value using the following formula: decay * (1 - math.exp(-x * self.beta))
    """

    def __init__(self, beta: float, **kwargs):
        self.beta = beta

    def __call__(self, decay: float, step: int, total_steps: int) -> float:
        x = step / total_steps
        return decay * (1 - math.exp(-x * self.beta))


EMA_DECAY_FUNCTIONS = {"constant": ConstantDecay, "threshold": ThresholdDecay, "exp": ExpDecay}
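As a quick numerical sanity check of the three schedules (assuming the module is importable as `super_gradients.training.utils.ema_decay_schedules`, per the import in ema.py above; printed values are approximate):

from super_gradients.training.utils.ema_decay_schedules import ConstantDecay, ExpDecay, ThresholdDecay

decay, total_steps = 0.9999, 10_000

print(ExpDecay(beta=15)(decay, 400, total_steps))    # ~0.451  : early on, the EMA tracks the model closely
print(ExpDecay(beta=15)(decay, 4000, total_steps))   # ~0.9974 : at ~40% of training, nearly saturated
print(ThresholdDecay()(decay, 4000, total_steps))    # ~0.9978
print(ConstantDecay()(decay, 4000, total_steps))     # 0.9999  : flat schedule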