Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#763 Added time factor to pre launch callback - batch size selector

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-685_time_factor_for_batch_cb
@@ -9,6 +9,8 @@ from super_gradients import is_distributed
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.training import models
 from super_gradients.training import models
 from torch.distributed import barrier
 from torch.distributed import barrier
+import cv2
+import numpy as np
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
@@ -31,7 +33,7 @@ class AutoTrainBatchSizeSelectionCallback(PreLaunchCallback):
     AutoTrainBatchSizeSelectionCallback
     AutoTrainBatchSizeSelectionCallback
 
 
     Modifies cfg.dataset_params.train_dataloader_params.batch_size by searching for the maximal batch size that fits
     Modifies cfg.dataset_params.train_dataloader_params.batch_size by searching for the maximal batch size that fits
-     gpu memory. Works out of the box for DDP.
+     gpu memory, or the one resulting in the fastest time for the selected number of train dataloader iterations. Works out of the box for DDP.
 
 
     The search is done by running a few forward passes for increasing batch sizes, until CUDA OUT OF MEMORY is raised:
     The search is done by running a few forward passes for increasing batch sizes, until CUDA OUT OF MEMORY is raised:
 
 
@@ -68,14 +70,19 @@ class AutoTrainBatchSizeSelectionCallback(PreLaunchCallback):
 
 
     :param scale_lr: bool, whether to linearly scale cfg.training_hyperparams.initial_lr, i.e multiply by
     :param scale_lr: bool, whether to linearly scale cfg.training_hyperparams.initial_lr, i.e multiply by
      FOUND_BATCH_SIZE/cfg.dataset_params.train_dataloader_params.batch_size (default=True)
      FOUND_BATCH_SIZE/cfg.dataset_params.train_dataloader_params.batch_size (default=True)
+    :param mode: str, one of ["fastest","largest"], whether to select the largest batch size that fits memory or the one
+     that resulted in the fastest overall execution (default="fastest").
     """
     """
 
 
-    def __init__(self, min_batch_size: int, size_step: int, num_forward_passes: int = 3, max_batch_size=None, scale_lr: bool = True):
+    def __init__(self, min_batch_size: int, size_step: int, num_forward_passes: int = 3, max_batch_size=None, scale_lr: bool = True, mode: str = "fastest"):
+        if mode not in ["fastest", "largest"]:
+            raise TypeError(f"Expected mode to be one of: ['fastest','largest'], got {mode}")
         self.scale_lr = scale_lr
         self.scale_lr = scale_lr
         self.min_batch_size = min_batch_size
         self.min_batch_size = min_batch_size
         self.size_step = size_step
         self.size_step = size_step
         self.max_batch_size = max_batch_size
         self.max_batch_size = max_batch_size
         self.num_forward_passes = num_forward_passes
         self.num_forward_passes = num_forward_passes
+        self.mode = mode
 
 
     def __call__(self, cfg: DictConfig) -> DictConfig:
     def __call__(self, cfg: DictConfig) -> DictConfig:
 
 
@@ -104,11 +111,22 @@ class AutoTrainBatchSizeSelectionCallback(PreLaunchCallback):
         tmp_cfg.training_hyperparams.kill_ddp_pgroup_on_end = False
         tmp_cfg.training_hyperparams.kill_ddp_pgroup_on_end = False
         tmp_cfg.pre_launch_callbacks_list = []
         tmp_cfg.pre_launch_callbacks_list = []
 
 
-        while True:
+        fastest_batch_time = np.inf
+        fastest_batch_size = curr_batch_size
+
+        bs_found = False
+
+        while not bs_found:
             tmp_cfg.dataset_params.train_dataloader_params.batch_size = curr_batch_size
             tmp_cfg.dataset_params.train_dataloader_params.batch_size = curr_batch_size
 
 
             try:
             try:
+                passes_start = cv2.getTickCount()
                 Trainer.train_from_config(tmp_cfg)
                 Trainer.train_from_config(tmp_cfg)
+                curr_batch_time = (cv2.getTickCount() - passes_start) / cv2.getTickFrequency()
+                logger.info(f"Batch size = {curr_batch_size} time for {self.num_forward_passes} forward passes: {curr_batch_time} seconds.")
+                if curr_batch_time < fastest_batch_time:
+                    fastest_batch_size = curr_batch_size
+                    fastest_batch_time = curr_batch_time
 
 
             except RuntimeError as e:
             except RuntimeError as e:
                 if "out of memory" in str(e):
                 if "out of memory" in str(e):
@@ -116,26 +134,32 @@ class AutoTrainBatchSizeSelectionCallback(PreLaunchCallback):
                         logger.error("Ran out of memory for the smallest batch, try setting smaller min_batch_size.")
                         logger.error("Ran out of memory for the smallest batch, try setting smaller min_batch_size.")
                         raise e
                         raise e
                     else:
                     else:
-                        logger.info(f"Ran out of memory for {curr_batch_size}, setting batch size to {curr_batch_size - self.size_step}.")
-                        self._adapt_lr_if_needed(cfg, found_batch_size=curr_batch_size - self.size_step)
-                        cfg.dataset_params.train_dataloader_params.batch_size = curr_batch_size - self.size_step
-                        self._clear_model_gpu_mem(model)
-                        return cfg
+                        selected_batch_size = curr_batch_size - self.size_step if self.mode == "largest" else fastest_batch_size
+                        msg = f"Ran out of memory for {curr_batch_size}, setting batch size to {selected_batch_size}."
+                        bs_found = True
                 else:
                 else:
                     raise e
                     raise e
 
 
             else:
             else:
                 if self.max_batch_size is not None and curr_batch_size >= self.max_batch_size:
                 if self.max_batch_size is not None and curr_batch_size >= self.max_batch_size:
-                    logger.info(
-                        f"Did not run out of memory for {curr_batch_size} >= max_batch_size={self.max_batch_size}, " f"setting batch to {self.max_batch_size}."
+                    selected_batch_size = self.max_batch_size if self.mode == "largest" else fastest_batch_size
+                    msg = (
+                        f"Did not run out of memory for {curr_batch_size} >= max_batch_size={self.max_batch_size}, " f"setting batch to {selected_batch_size}."
                     )
                     )
-                    self._adapt_lr_if_needed(cfg, found_batch_size=self.max_batch_size)
-                    cfg.dataset_params.train_dataloader_params.batch_size = self.max_batch_size
+                    bs_found = True
+                else:
+                    logger.info(f"Did not run out of memory for {curr_batch_size}, retrying batch {curr_batch_size + self.size_step}.")
+                    curr_batch_size += self.size_step
                     self._clear_model_gpu_mem(model)
                     self._clear_model_gpu_mem(model)
-                    return cfg
-                logger.info(f"Did not run out of memory for {curr_batch_size}, retrying batch {curr_batch_size + self.size_step}.")
-                curr_batch_size += self.size_step
-                self._clear_model_gpu_mem(model)
+
+        return self._inject_selected_batch_size_to_config(cfg, model, msg, selected_batch_size)
+
+    def _inject_selected_batch_size_to_config(self, cfg, model, msg, selected_batch_size):
+        logger.info(msg)
+        self._adapt_lr_if_needed(cfg, found_batch_size=selected_batch_size)
+        cfg.dataset_params.train_dataloader_params.batch_size = selected_batch_size
+        self._clear_model_gpu_mem(model)
+        return cfg
 
 
     def _adapt_lr_if_needed(self, cfg: DictConfig, found_batch_size: int) -> DictConfig:
     def _adapt_lr_if_needed(self, cfg: DictConfig, found_batch_size: int) -> DictConfig:
         if self.scale_lr:
         if self.scale_lr:
Discard
@@ -73,6 +73,7 @@ class TestAutoBatchSelectionSingleGPU(unittest.TestCase):
                                 "max_batch_size": 64,
                                 "max_batch_size": 64,
                                 "num_forward_passes": 3,
                                 "num_forward_passes": 3,
                                 "scale_lr": False,
                                 "scale_lr": False,
+                                "mode": "largest",
                             }
                             }
                         }
                         }
                     ),
                     ),
@@ -93,7 +94,9 @@ class TestAutoBatchSelectionSingleGPU(unittest.TestCase):
             OmegaConf.set_struct(cfg, True)
             OmegaConf.set_struct(cfg, True)
             with open_dict(cfg):
             with open_dict(cfg):
                 cfg.pre_launch_callbacks_list = [
                 cfg.pre_launch_callbacks_list = [
-                    OmegaConf.create({"AutoTrainBatchSizeSelectionCallback": {"min_batch_size": 64, "size_step": 10000, "num_forward_passes": 3}}),
+                    OmegaConf.create(
+                        {"AutoTrainBatchSizeSelectionCallback": {"min_batch_size": 64, "size_step": 10000, "num_forward_passes": 3, "mode": "largest"}}
+                    ),
                     OmegaConf.create({"PreLaunchTrainBatchSizeVerificationCallback": {"batch_size": 64}}),
                     OmegaConf.create({"PreLaunchTrainBatchSizeVerificationCallback": {"batch_size": 64}}),
                     OmegaConf.create(
                     OmegaConf.create(
                         {
                         {
@@ -117,7 +120,15 @@ class TestAutoBatchSelectionSingleGPU(unittest.TestCase):
             with open_dict(cfg):
             with open_dict(cfg):
                 cfg.pre_launch_callbacks_list = [
                 cfg.pre_launch_callbacks_list = [
                     OmegaConf.create(
                     OmegaConf.create(
-                        {"AutoTrainBatchSizeSelectionCallback": {"min_batch_size": 32, "size_step": 32, "max_batch_size": 64, "num_forward_passes": 3}}
+                        {
+                            "AutoTrainBatchSizeSelectionCallback": {
+                                "min_batch_size": 32,
+                                "size_step": 32,
+                                "max_batch_size": 64,
+                                "num_forward_passes": 3,
+                                "mode": "largest",
+                            }
+                        }
                     ),
                     ),
                     OmegaConf.create({"PreLaunchTrainBatchSizeVerificationCallback": {"batch_size": 64}}),
                     OmegaConf.create({"PreLaunchTrainBatchSizeVerificationCallback": {"batch_size": 64}}),
                     OmegaConf.create(
                     OmegaConf.create(
Discard