Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#572 new generated docs

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-000_new_generated_docs
84 changed files with 6562 additions and 14256 deletions
  1. 36
    55
      docs/_modules/index.html
  2. 0
    999
      docs/_modules/logging.html
  3. 0
    144
      docs/_modules/super_gradients/common/abstractions/abstract_logger.html
  4. 136
    120
      docs/_modules/super_gradients/common/auto_logging/auto_logger.html
  5. 294
    0
      docs/_modules/super_gradients/common/auto_logging/console_logging.html
  6. 28
    26
      docs/_modules/super_gradients/common/aws_connection/aws_connector.html
  7. 0
    248
      docs/_modules/super_gradients/common/aws_connection/aws_secrets_manager_connector.html
  8. 133
    0
      docs/_modules/super_gradients/common/crash_handler/crash_handler.html
  9. 27
    25
      docs/_modules/super_gradients/common/data_connection/s3_connector.html
  10. 26
    24
      docs/_modules/super_gradients/common/data_interface/adnn_model_repository_data_interface.html
  11. 24
    22
      docs/_modules/super_gradients/common/data_interface/dataset_data_interface.html
  12. 24
    22
      docs/_modules/super_gradients/common/data_types/enum/deep_learning_task.html
  13. 25
    23
      docs/_modules/super_gradients/common/data_types/enum/evaluation_type.html
  14. 45
    26
      docs/_modules/super_gradients/common/data_types/enum/multi_gpu_mode.html
  15. 31
    28
      docs/_modules/super_gradients/common/data_types/enum/strict_load.html
  16. 128
    0
      docs/_modules/super_gradients/common/data_types/enum/upsample_mode.html
  17. 0
    206
      docs/_modules/super_gradients/common/decorators/deci_logger.html
  18. 24
    22
      docs/_modules/super_gradients/common/decorators/explicit_params_validator.html
  19. 26
    24
      docs/_modules/super_gradients/common/decorators/singleton.html
  20. 145
    54
      docs/_modules/super_gradients/common/environment/env_helpers.html
  21. 373
    0
      docs/_modules/super_gradients/common/object_names.html
  22. 785
    0
      docs/_modules/super_gradients/training/dataloaders/dataloaders.html
  23. 0
    207
      docs/_modules/super_gradients/training/datasets/all_datasets.html
  24. 0
    558
      docs/_modules/super_gradients/training/datasets/auto_augment.html
  25. 194
    0
      docs/_modules/super_gradients/training/datasets/classification_datasets/cifar.html
  26. 139
    0
      docs/_modules/super_gradients/training/datasets/classification_datasets/imagenet_dataset.html
  27. 31
    29
      docs/_modules/super_gradients/training/datasets/data_augmentation.html
  28. 0
    959
      docs/_modules/super_gradients/training/datasets/dataset_interfaces/dataset_interface.html
  29. 0
    811
      docs/_modules/super_gradients/training/datasets/datasets_utils.html
  30. 126
    273
      docs/_modules/super_gradients/training/datasets/detection_datasets/coco_detection.html
  31. 99
    91
      docs/_modules/super_gradients/training/datasets/detection_datasets/detection_dataset.html
  32. 109
    59
      docs/_modules/super_gradients/training/datasets/detection_datasets/pascal_voc_detection.html
  33. 0
    422
      docs/_modules/super_gradients/training/datasets/mixup.html
  34. 0
    222
      docs/_modules/super_gradients/training/datasets/segmentation_datasets/cityscape_segmentation.html
  35. 36
    49
      docs/_modules/super_gradients/training/datasets/segmentation_datasets/coco_segmentation.html
  36. 0
    155
      docs/_modules/super_gradients/training/datasets/segmentation_datasets/pascal_aug_segmentation.html
  37. 128
    34
      docs/_modules/super_gradients/training/datasets/segmentation_datasets/pascal_voc_segmentation.html
  38. 42
    141
      docs/_modules/super_gradients/training/datasets/segmentation_datasets/segmentation_dataset.html
  39. 26
    23
      docs/_modules/super_gradients/training/datasets/segmentation_datasets/supervisely_persons_segmentation.html
  40. 30
    28
      docs/_modules/super_gradients/training/datasets/sg_dataset.html
  41. 0
    158
      docs/_modules/super_gradients/training/exceptions/dataset_exceptions.html
  42. 0
    150
      docs/_modules/super_gradients/training/exceptions/sg_model_exceptions.html
  43. 115
    95
      docs/_modules/super_gradients/training/kd_trainer/kd_trainer.html
  44. 0
    251
      docs/_modules/super_gradients/training/legacy/utils.html
  45. 25
    23
      docs/_modules/super_gradients/training/losses/bce_dice_loss.html
  46. 0
    166
      docs/_modules/super_gradients/training/losses/ddrnet_loss.html
  47. 40
    23
      docs/_modules/super_gradients/training/losses/dice_ce_edge_loss.html
  48. 25
    23
      docs/_modules/super_gradients/training/losses/focal_loss.html
  49. 34
    23
      docs/_modules/super_gradients/training/losses/kd_losses.html
  50. 29
    27
      docs/_modules/super_gradients/training/losses/label_smoothing_cross_entropy_loss.html
  51. 0
    227
      docs/_modules/super_gradients/training/losses/ohem_ce_loss.html
  52. 25
    23
      docs/_modules/super_gradients/training/losses/r_squared_loss.html
  53. 35
    24
      docs/_modules/super_gradients/training/losses/shelfnet_ohem_loss.html
  54. 35
    24
      docs/_modules/super_gradients/training/losses/shelfnet_semantic_encoding_loss.html
  55. 38
    27
      docs/_modules/super_gradients/training/losses/ssd_loss.html
  56. 0
    180
      docs/_modules/super_gradients/training/losses/yolo_v3_loss.html
  57. 0
    340
      docs/_modules/super_gradients/training/losses/yolo_v5_loss.html
  58. 531
    75
      docs/_modules/super_gradients/training/losses/yolox_loss.html
  59. 36
    31
      docs/_modules/super_gradients/training/metrics/classification_metrics.html
  60. 160
    53
      docs/_modules/super_gradients/training/metrics/detection_metrics.html
  61. 0
    226
      docs/_modules/super_gradients/training/metrics/metric_utils.html
  62. 136
    105
      docs/_modules/super_gradients/training/metrics/segmentation_metrics.html
  63. 0
    184
      docs/_modules/super_gradients/training/models/sg_module.html
  64. 0
    187
      docs/_modules/super_gradients/training/params.html
  65. 423
    249
      docs/_modules/super_gradients/training/sg_trainer/sg_trainer.html
  66. 231
    0
      docs/_modules/super_gradients/training/training_hyperparams/training_hyperparams.html
  67. 999
    0
      docs/_modules/super_gradients/training/transforms/transforms.html
  68. 0
    885
      docs/_modules/super_gradients/training/utils/callbacks.html
  69. 147
    102
      docs/_modules/super_gradients/training/utils/checkpoint_utils.html
  70. 0
    999
      docs/_modules/super_gradients/training/utils/detection_utils.html
  71. 0
    273
      docs/_modules/super_gradients/training/utils/distributed_training_utils.html
  72. 0
    268
      docs/_modules/super_gradients/training/utils/early_stopping.html
  73. 0
    245
      docs/_modules/super_gradients/training/utils/ema.html
  74. 0
    143
      docs/_modules/super_gradients/training/utils/export_utils.html
  75. 0
    339
      docs/_modules/super_gradients/training/utils/module_utils.html
  76. 0
    230
      docs/_modules/super_gradients/training/utils/optimizer_utils.html
  77. 0
    247
      docs/_modules/super_gradients/training/utils/optimizers/rmsprop_tf.html
  78. 0
    143
      docs/_modules/super_gradients/training/utils/regularization_utils.html
  79. 0
    321
      docs/_modules/super_gradients/training/utils/segmentation_utils.html
  80. 0
    449
      docs/_modules/super_gradients/training/utils/sg_model_utils.html
  81. 0
    265
      docs/_modules/super_gradients/training/utils/ssd_utils.html
  82. 98
    71
      docs/_modules/super_gradients/training/utils/utils.html
  83. 130
    0
      docs/_modules/super_gradients/training/utils/version_utils.html
  84. 0
    253
      docs/_modules/super_gradients/training/utils/weight_averaging_utils.html
@@ -3,10 +3,11 @@
 <head>
 <head>
   <meta charset="utf-8" />
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-  <title>Overview: module code &mdash; SuperGradients 1.0 documentation</title>
+  <title>Overview: module code &mdash; SuperGradients 3.0.3 documentation</title>
       <link rel="stylesheet" href="../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../_static/graphviz.css" type="text/css" />
       <link rel="stylesheet" href="../_static/graphviz.css" type="text/css" />
+      <link rel="stylesheet" href="../_static/custom.css" type="text/css" />
   <!--[if lt IE 9]>
   <!--[if lt IE 9]>
     <script src="../_static/js/html5shiv.min.js"></script>
     <script src="../_static/js/html5shiv.min.js"></script>
   <![endif]-->
   <![endif]-->
@@ -14,7 +15,9 @@
         <script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script>
         <script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script>
         <script src="../_static/jquery.js"></script>
         <script src="../_static/jquery.js"></script>
         <script src="../_static/underscore.js"></script>
         <script src="../_static/underscore.js"></script>
+        <script src="../_static/_sphinx_javascript_frameworks_compat.js"></script>
         <script src="../_static/doctools.js"></script>
         <script src="../_static/doctools.js"></script>
+        <script src="../_static/sphinx_highlight.js"></script>
     <script src="../_static/js/theme.js"></script>
     <script src="../_static/js/theme.js"></script>
     <link rel="index" title="Index" href="../genindex.html" />
     <link rel="index" title="Index" href="../genindex.html" />
     <link rel="search" title="Search" href="../search.html" /> 
     <link rel="search" title="Search" href="../search.html" /> 
@@ -35,30 +38,29 @@
   </form>
   </form>
 </div>
 </div>
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
-              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
+              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
 <ul>
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../welcome.html#supergradients">SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../welcome.html#quick-installation">Quick Installation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../welcome.html#what-s-new">What’s New</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../welcome.html#coming-soon">Coming soon</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../welcome.html#table-of-content">Table of Content</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../welcome.html#getting-started">Getting Started</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../welcome.html#advanced-features">Advanced Features</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../welcome.html#installation-methods">Installation Methods</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../welcome.html#documentation">Documentation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../welcome.html#contributing">Contributing</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../welcome.html#citation">Citation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../welcome.html#community">Community</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../welcome.html#license">License</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../welcome.html#deci-platform">Deci Platform</a></li>
 </ul>
 </ul>
-<p class="caption"><span class="caption-text">Technical Documentation</span></p>
+<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
 <ul>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../super_gradients.training.html">Training package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../super_gradients.training.html">Training package</a></li>
-</ul>
-<p class="caption"><span class="caption-text">User Guide</span></p>
-<ul>
-<li class="toctree-l1"><a class="reference internal" href="../user_guide.html">What is SuperGradients?</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../user_guide.html#installation">Installation</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../user_guide.html#training-parameters">Training Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../user_guide.html#network-architectures">Network Architectures</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../user_guide.html#pretrained-models">Pretrained Models</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
 </ul>
 </ul>
 
 
         </div>
         </div>
@@ -85,10 +87,10 @@
            <div itemprop="articleBody">
            <div itemprop="articleBody">
              
              
   <h1>All modules for which code is available</h1>
   <h1>All modules for which code is available</h1>
-<ul><li><a href="super_gradients/common/abstractions/abstract_logger.html">super_gradients.common.abstractions.abstract_logger</a></li>
-<li><a href="super_gradients/common/auto_logging/auto_logger.html">super_gradients.common.auto_logging.auto_logger</a></li>
+<ul><li><a href="super_gradients/common/auto_logging/auto_logger.html">super_gradients.common.auto_logging.auto_logger</a></li>
+<li><a href="super_gradients/common/auto_logging/console_logging.html">super_gradients.common.auto_logging.console_logging</a></li>
 <li><a href="super_gradients/common/aws_connection/aws_connector.html">super_gradients.common.aws_connection.aws_connector</a></li>
 <li><a href="super_gradients/common/aws_connection/aws_connector.html">super_gradients.common.aws_connection.aws_connector</a></li>
-<li><a href="super_gradients/common/aws_connection/aws_secrets_manager_connector.html">super_gradients.common.aws_connection.aws_secrets_manager_connector</a></li>
+<li><a href="super_gradients/common/crash_handler/crash_handler.html">super_gradients.common.crash_handler.crash_handler</a></li>
 <li><a href="super_gradients/common/data_connection/s3_connector.html">super_gradients.common.data_connection.s3_connector</a></li>
 <li><a href="super_gradients/common/data_connection/s3_connector.html">super_gradients.common.data_connection.s3_connector</a></li>
 <li><a href="super_gradients/common/data_interface/adnn_model_repository_data_interface.html">super_gradients.common.data_interface.adnn_model_repository_data_interface</a></li>
 <li><a href="super_gradients/common/data_interface/adnn_model_repository_data_interface.html">super_gradients.common.data_interface.adnn_model_repository_data_interface</a></li>
 <li><a href="super_gradients/common/data_interface/dataset_data_interface.html">super_gradients.common.data_interface.dataset_data_interface</a></li>
 <li><a href="super_gradients/common/data_interface/dataset_data_interface.html">super_gradients.common.data_interface.dataset_data_interface</a></li>
@@ -96,37 +98,29 @@
 <li><a href="super_gradients/common/data_types/enum/evaluation_type.html">super_gradients.common.data_types.enum.evaluation_type</a></li>
 <li><a href="super_gradients/common/data_types/enum/evaluation_type.html">super_gradients.common.data_types.enum.evaluation_type</a></li>
 <li><a href="super_gradients/common/data_types/enum/multi_gpu_mode.html">super_gradients.common.data_types.enum.multi_gpu_mode</a></li>
 <li><a href="super_gradients/common/data_types/enum/multi_gpu_mode.html">super_gradients.common.data_types.enum.multi_gpu_mode</a></li>
 <li><a href="super_gradients/common/data_types/enum/strict_load.html">super_gradients.common.data_types.enum.strict_load</a></li>
 <li><a href="super_gradients/common/data_types/enum/strict_load.html">super_gradients.common.data_types.enum.strict_load</a></li>
-<li><a href="super_gradients/common/decorators/deci_logger.html">super_gradients.common.decorators.deci_logger</a></li>
+<li><a href="super_gradients/common/data_types/enum/upsample_mode.html">super_gradients.common.data_types.enum.upsample_mode</a></li>
 <li><a href="super_gradients/common/decorators/explicit_params_validator.html">super_gradients.common.decorators.explicit_params_validator</a></li>
 <li><a href="super_gradients/common/decorators/explicit_params_validator.html">super_gradients.common.decorators.explicit_params_validator</a></li>
 <li><a href="super_gradients/common/decorators/singleton.html">super_gradients.common.decorators.singleton</a></li>
 <li><a href="super_gradients/common/decorators/singleton.html">super_gradients.common.decorators.singleton</a></li>
 <li><a href="super_gradients/common/environment/env_helpers.html">super_gradients.common.environment.env_helpers</a></li>
 <li><a href="super_gradients/common/environment/env_helpers.html">super_gradients.common.environment.env_helpers</a></li>
-<li><a href="super_gradients/training/datasets/all_datasets.html">super_gradients.training.datasets.all_datasets</a></li>
-<li><a href="super_gradients/training/datasets/auto_augment.html">super_gradients.training.datasets.auto_augment</a></li>
+<li><a href="super_gradients/common/object_names.html">super_gradients.common.object_names</a></li>
+<li><a href="super_gradients/training/dataloaders/dataloaders.html">super_gradients.training.dataloaders.dataloaders</a></li>
+<li><a href="super_gradients/training/datasets/classification_datasets/cifar.html">super_gradients.training.datasets.classification_datasets.cifar</a></li>
+<li><a href="super_gradients/training/datasets/classification_datasets/imagenet_dataset.html">super_gradients.training.datasets.classification_datasets.imagenet_dataset</a></li>
 <li><a href="super_gradients/training/datasets/data_augmentation.html">super_gradients.training.datasets.data_augmentation</a></li>
 <li><a href="super_gradients/training/datasets/data_augmentation.html">super_gradients.training.datasets.data_augmentation</a></li>
-<li><a href="super_gradients/training/datasets/dataset_interfaces/dataset_interface.html">super_gradients.training.datasets.dataset_interfaces.dataset_interface</a></li>
-<li><a href="super_gradients/training/datasets/datasets_utils.html">super_gradients.training.datasets.datasets_utils</a></li>
 <li><a href="super_gradients/training/datasets/detection_datasets/coco_detection.html">super_gradients.training.datasets.detection_datasets.coco_detection</a></li>
 <li><a href="super_gradients/training/datasets/detection_datasets/coco_detection.html">super_gradients.training.datasets.detection_datasets.coco_detection</a></li>
 <li><a href="super_gradients/training/datasets/detection_datasets/detection_dataset.html">super_gradients.training.datasets.detection_datasets.detection_dataset</a></li>
 <li><a href="super_gradients/training/datasets/detection_datasets/detection_dataset.html">super_gradients.training.datasets.detection_datasets.detection_dataset</a></li>
 <li><a href="super_gradients/training/datasets/detection_datasets/pascal_voc_detection.html">super_gradients.training.datasets.detection_datasets.pascal_voc_detection</a></li>
 <li><a href="super_gradients/training/datasets/detection_datasets/pascal_voc_detection.html">super_gradients.training.datasets.detection_datasets.pascal_voc_detection</a></li>
-<li><a href="super_gradients/training/datasets/mixup.html">super_gradients.training.datasets.mixup</a></li>
-<li><a href="super_gradients/training/datasets/segmentation_datasets/cityscape_segmentation.html">super_gradients.training.datasets.segmentation_datasets.cityscape_segmentation</a></li>
 <li><a href="super_gradients/training/datasets/segmentation_datasets/coco_segmentation.html">super_gradients.training.datasets.segmentation_datasets.coco_segmentation</a></li>
 <li><a href="super_gradients/training/datasets/segmentation_datasets/coco_segmentation.html">super_gradients.training.datasets.segmentation_datasets.coco_segmentation</a></li>
-<li><a href="super_gradients/training/datasets/segmentation_datasets/pascal_aug_segmentation.html">super_gradients.training.datasets.segmentation_datasets.pascal_aug_segmentation</a></li>
 <li><a href="super_gradients/training/datasets/segmentation_datasets/pascal_voc_segmentation.html">super_gradients.training.datasets.segmentation_datasets.pascal_voc_segmentation</a></li>
 <li><a href="super_gradients/training/datasets/segmentation_datasets/pascal_voc_segmentation.html">super_gradients.training.datasets.segmentation_datasets.pascal_voc_segmentation</a></li>
 <li><a href="super_gradients/training/datasets/segmentation_datasets/segmentation_dataset.html">super_gradients.training.datasets.segmentation_datasets.segmentation_dataset</a></li>
 <li><a href="super_gradients/training/datasets/segmentation_datasets/segmentation_dataset.html">super_gradients.training.datasets.segmentation_datasets.segmentation_dataset</a></li>
 <li><a href="super_gradients/training/datasets/segmentation_datasets/supervisely_persons_segmentation.html">super_gradients.training.datasets.segmentation_datasets.supervisely_persons_segmentation</a></li>
 <li><a href="super_gradients/training/datasets/segmentation_datasets/supervisely_persons_segmentation.html">super_gradients.training.datasets.segmentation_datasets.supervisely_persons_segmentation</a></li>
 <li><a href="super_gradients/training/datasets/sg_dataset.html">super_gradients.training.datasets.sg_dataset</a></li>
 <li><a href="super_gradients/training/datasets/sg_dataset.html">super_gradients.training.datasets.sg_dataset</a></li>
-<li><a href="super_gradients/training/exceptions/dataset_exceptions.html">super_gradients.training.exceptions.dataset_exceptions</a></li>
-<li><a href="super_gradients/training/exceptions/sg_model_exceptions.html">super_gradients.training.exceptions.sg_model_exceptions</a></li>
-<li><a href="super_gradients/training/kd_model/kd_model.html">super_gradients.training.kd_model.kd_model</a></li>
-<li><a href="super_gradients/training/legacy/utils.html">super_gradients.training.legacy.utils</a></li>
+<li><a href="super_gradients/training/kd_trainer/kd_trainer.html">super_gradients.training.kd_trainer.kd_trainer</a></li>
 <li><a href="super_gradients/training/losses/bce_dice_loss.html">super_gradients.training.losses.bce_dice_loss</a></li>
 <li><a href="super_gradients/training/losses/bce_dice_loss.html">super_gradients.training.losses.bce_dice_loss</a></li>
-<li><a href="super_gradients/training/losses/ddrnet_loss.html">super_gradients.training.losses.ddrnet_loss</a></li>
 <li><a href="super_gradients/training/losses/dice_ce_edge_loss.html">super_gradients.training.losses.dice_ce_edge_loss</a></li>
 <li><a href="super_gradients/training/losses/dice_ce_edge_loss.html">super_gradients.training.losses.dice_ce_edge_loss</a></li>
 <li><a href="super_gradients/training/losses/focal_loss.html">super_gradients.training.losses.focal_loss</a></li>
 <li><a href="super_gradients/training/losses/focal_loss.html">super_gradients.training.losses.focal_loss</a></li>
 <li><a href="super_gradients/training/losses/kd_losses.html">super_gradients.training.losses.kd_losses</a></li>
 <li><a href="super_gradients/training/losses/kd_losses.html">super_gradients.training.losses.kd_losses</a></li>
 <li><a href="super_gradients/training/losses/label_smoothing_cross_entropy_loss.html">super_gradients.training.losses.label_smoothing_cross_entropy_loss</a></li>
 <li><a href="super_gradients/training/losses/label_smoothing_cross_entropy_loss.html">super_gradients.training.losses.label_smoothing_cross_entropy_loss</a></li>
-<li><a href="super_gradients/training/losses/ohem_ce_loss.html">super_gradients.training.losses.ohem_ce_loss</a></li>
 <li><a href="super_gradients/training/losses/r_squared_loss.html">super_gradients.training.losses.r_squared_loss</a></li>
 <li><a href="super_gradients/training/losses/r_squared_loss.html">super_gradients.training.losses.r_squared_loss</a></li>
 <li><a href="super_gradients/training/losses/shelfnet_ohem_loss.html">super_gradients.training.losses.shelfnet_ohem_loss</a></li>
 <li><a href="super_gradients/training/losses/shelfnet_ohem_loss.html">super_gradients.training.losses.shelfnet_ohem_loss</a></li>
 <li><a href="super_gradients/training/losses/shelfnet_semantic_encoding_loss.html">super_gradients.training.losses.shelfnet_semantic_encoding_loss</a></li>
 <li><a href="super_gradients/training/losses/shelfnet_semantic_encoding_loss.html">super_gradients.training.losses.shelfnet_semantic_encoding_loss</a></li>
@@ -134,26 +128,13 @@
 <li><a href="super_gradients/training/losses/yolox_loss.html">super_gradients.training.losses.yolox_loss</a></li>
 <li><a href="super_gradients/training/losses/yolox_loss.html">super_gradients.training.losses.yolox_loss</a></li>
 <li><a href="super_gradients/training/metrics/classification_metrics.html">super_gradients.training.metrics.classification_metrics</a></li>
 <li><a href="super_gradients/training/metrics/classification_metrics.html">super_gradients.training.metrics.classification_metrics</a></li>
 <li><a href="super_gradients/training/metrics/detection_metrics.html">super_gradients.training.metrics.detection_metrics</a></li>
 <li><a href="super_gradients/training/metrics/detection_metrics.html">super_gradients.training.metrics.detection_metrics</a></li>
-<li><a href="super_gradients/training/metrics/metric_utils.html">super_gradients.training.metrics.metric_utils</a></li>
 <li><a href="super_gradients/training/metrics/segmentation_metrics.html">super_gradients.training.metrics.segmentation_metrics</a></li>
 <li><a href="super_gradients/training/metrics/segmentation_metrics.html">super_gradients.training.metrics.segmentation_metrics</a></li>
-<li><a href="super_gradients/training/models/sg_module.html">super_gradients.training.models.sg_module</a></li>
-<li><a href="super_gradients/training/sg_model/sg_model.html">super_gradients.training.sg_model.sg_model</a></li>
-<li><a href="super_gradients/training/utils/callbacks.html">super_gradients.training.utils.callbacks</a></li>
+<li><a href="super_gradients/training/sg_trainer/sg_trainer.html">super_gradients.training.sg_trainer.sg_trainer</a></li>
+<li><a href="super_gradients/training/training_hyperparams/training_hyperparams.html">super_gradients.training.training_hyperparams.training_hyperparams</a></li>
+<li><a href="super_gradients/training/transforms/transforms.html">super_gradients.training.transforms.transforms</a></li>
 <li><a href="super_gradients/training/utils/checkpoint_utils.html">super_gradients.training.utils.checkpoint_utils</a></li>
 <li><a href="super_gradients/training/utils/checkpoint_utils.html">super_gradients.training.utils.checkpoint_utils</a></li>
-<li><a href="super_gradients/training/utils/detection_utils.html">super_gradients.training.utils.detection_utils</a></li>
-<li><a href="super_gradients/training/utils/distributed_training_utils.html">super_gradients.training.utils.distributed_training_utils</a></li>
-<li><a href="super_gradients/training/utils/early_stopping.html">super_gradients.training.utils.early_stopping</a></li>
-<li><a href="super_gradients/training/utils/ema.html">super_gradients.training.utils.ema</a></li>
-<li><a href="super_gradients/training/utils/export_utils.html">super_gradients.training.utils.export_utils</a></li>
-<li><a href="super_gradients/training/utils/module_utils.html">super_gradients.training.utils.module_utils</a></li>
-<li><a href="super_gradients/training/utils/optimizer_utils.html">super_gradients.training.utils.optimizer_utils</a></li>
-<li><a href="super_gradients/training/utils/optimizers/rmsprop_tf.html">super_gradients.training.utils.optimizers.rmsprop_tf</a></li>
-<li><a href="super_gradients/training/utils/regularization_utils.html">super_gradients.training.utils.regularization_utils</a></li>
-<li><a href="super_gradients/training/utils/segmentation_utils.html">super_gradients.training.utils.segmentation_utils</a></li>
-<li><a href="super_gradients/training/utils/sg_model_utils.html">super_gradients.training.utils.sg_model_utils</a></li>
-<li><a href="super_gradients/training/utils/ssd_utils.html">super_gradients.training.utils.ssd_utils</a></li>
 <li><a href="super_gradients/training/utils/utils.html">super_gradients.training.utils.utils</a></li>
 <li><a href="super_gradients/training/utils/utils.html">super_gradients.training.utils.utils</a></li>
-<li><a href="super_gradients/training/utils/weight_averaging_utils.html">super_gradients.training.utils.weight_averaging_utils</a></li>
+<li><a href="super_gradients/training/utils/version_utils.html">super_gradients.training.utils.version_utils</a></li>
 </ul>
 </ul>
 
 
            </div>
            </div>
@@ -183,4 +164,4 @@
   </script> 
   </script> 
 
 
 </body>
 </body>
-</html>
+</html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
2001
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
2023
2024
2025
2026
2027
2028
2029
2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061
2062
2063
2064
2065
2066
2067
2068
2069
2070
2071
2072
2073
2074
2075
2076
2077
2078
2079
2080
2081
2082
2083
2084
2085
2086
2087
2088
2089
2090
2091
2092
2093
2094
2095
2096
2097
2098
2099
2100
2101
2102
2103
2104
2105
2106
2107
2108
2109
2110
2111
2112
2113
2114
2115
2116
2117
2118
2119
2120
2121
2122
2123
2124
2125
2126
2127
2128
2129
2130
2131
2132
2133
2134
2135
2136
2137
2138
2139
2140
2141
2142
2143
2144
2145
2146
2147
2148
2149
2150
2151
2152
2153
2154
2155
2156
2157
2158
2159
2160
2161
2162
2163
2164
2165
2166
2167
2168
2169
2170
2171
2172
2173
2174
2175
2176
2177
2178
2179
2180
2181
2182
2183
2184
2185
2186
2187
2188
2189
2190
2191
2192
2193
2194
2195
2196
2197
2198
2199
2200
2201
2202
2203
2204
2205
2206
2207
2208
2209
2210
2211
2212
2213
2214
2215
2216
2217
2218
2219
2220
2221
2222
2223
2224
2225
2226
2227
2228
2229
2230
2231
2232
2233
2234
2235
2236
2237
2238
2239
2240
2241
2242
2243
2244
2245
2246
2247
2248
2249
2250
2251
2252
2253
2254
2255
2256
2257
2258
2259
2260
2261
2262
2263
2264
2265
2266
2267
2268
2269
2270
2271
2272
2273
2274
2275
2276
2277
2278
2279
2280
2281
2282
2283
2284
2285
2286
2287
2288
2289
2290
2291
2292
2293
2294
2295
2296
2297
2298
2299
2300
2301
2302
2303
2304
2305
2306
2307
2308
2309
2310
2311
2312
2313
2314
2315
2316
2317
2318
2319
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>logging &mdash; SuperGradients 1.0 documentation</title>
  7. <link rel="stylesheet" href="../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../_static/graphviz.css" type="text/css" />
  10. <!--[if lt IE 9]>
  11. <script src="../_static/js/html5shiv.min.js"></script>
  12. <![endif]-->
  13. <script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script>
  14. <script src="../_static/jquery.js"></script>
  15. <script src="../_static/underscore.js"></script>
  16. <script src="../_static/doctools.js"></script>
  17. <script src="../_static/js/theme.js"></script>
  18. <link rel="index" title="Index" href="../genindex.html" />
  19. <link rel="search" title="Search" href="../search.html" />
  20. </head>
  21. <body class="wy-body-for-nav">
  22. <div class="wy-grid-for-nav">
  23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  24. <div class="wy-side-scroll">
  25. <div class="wy-side-nav-search" >
  26. <a href="../index.html" class="icon icon-home"> SuperGradients
  27. </a>
  28. <div role="search">
  29. <form id="rtd-search-form" class="wy-form" action="../search.html" method="get">
  30. <input type="text" name="q" placeholder="Search docs" />
  31. <input type="hidden" name="check_keywords" value="yes" />
  32. <input type="hidden" name="area" value="default" />
  33. </form>
  34. </div>
  35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  36. <p class="caption"><span class="caption-text">SuperGradients</span></p>
  37. <ul>
  38. <li class="toctree-l1"><a class="reference internal" href="../welcome.html">SuperGradients</a></li>
  39. <li class="toctree-l1"><a class="reference internal" href="../super_gradients.common.html">Common</a></li>
  40. <li class="toctree-l1"><a class="reference internal" href="../super_gradients.training.html">Training</a></li>
  41. </ul>
  42. </div>
  43. </div>
  44. </nav>
  45. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  46. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  47. <a href="../index.html">SuperGradients</a>
  48. </nav>
  49. <div class="wy-nav-content">
  50. <div class="rst-content">
  51. <div role="navigation" aria-label="Page navigation">
  52. <ul class="wy-breadcrumbs">
  53. <li><a href="../index.html" class="icon icon-home"></a> &raquo;</li>
  54. <li><a href="index.html">Module code</a> &raquo;</li>
  55. <li>logging</li>
  56. <li class="wy-breadcrumbs-aside">
  57. </li>
  58. </ul>
  59. <hr/>
  60. </div>
  61. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  62. <div itemprop="articleBody">
  63. <h1>Source code for logging</h1><div class="highlight"><pre>
  64. <span></span><span class="c1"># Copyright 2001-2019 by Vinay Sajip. All Rights Reserved.</span>
  65. <span class="c1">#</span>
  66. <span class="c1"># Permission to use, copy, modify, and distribute this software and its</span>
  67. <span class="c1"># documentation for any purpose and without fee is hereby granted,</span>
  68. <span class="c1"># provided that the above copyright notice appear in all copies and that</span>
  69. <span class="c1"># both that copyright notice and this permission notice appear in</span>
  70. <span class="c1"># supporting documentation, and that the name of Vinay Sajip</span>
  71. <span class="c1"># not be used in advertising or publicity pertaining to distribution</span>
  72. <span class="c1"># of the software without specific, written prior permission.</span>
  73. <span class="c1"># VINAY SAJIP DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING</span>
  74. <span class="c1"># ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL</span>
  75. <span class="c1"># VINAY SAJIP BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR</span>
  76. <span class="c1"># ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER</span>
  77. <span class="c1"># IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT</span>
  78. <span class="c1"># OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.</span>
  79. <span class="sd">&quot;&quot;&quot;</span>
  80. <span class="sd">Logging package for Python. Based on PEP 282 and comments thereto in</span>
  81. <span class="sd">comp.lang.python.</span>
  82. <span class="sd">Copyright (C) 2001-2019 Vinay Sajip. All Rights Reserved.</span>
  83. <span class="sd">To use, simply &#39;import logging&#39; and log away!</span>
  84. <span class="sd">&quot;&quot;&quot;</span>
  85. <span class="kn">import</span> <span class="nn">sys</span><span class="o">,</span> <span class="nn">os</span><span class="o">,</span> <span class="nn">time</span><span class="o">,</span> <span class="nn">io</span><span class="o">,</span> <span class="nn">re</span><span class="o">,</span> <span class="nn">traceback</span><span class="o">,</span> <span class="nn">warnings</span><span class="o">,</span> <span class="nn">weakref</span><span class="o">,</span> <span class="nn">collections.abc</span>
  86. <span class="kn">from</span> <span class="nn">string</span> <span class="kn">import</span> <span class="n">Template</span>
  87. <span class="kn">from</span> <span class="nn">string</span> <span class="kn">import</span> <span class="n">Formatter</span> <span class="k">as</span> <span class="n">StrFormatter</span>
  88. <span class="n">__all__</span> <span class="o">=</span> <span class="p">[</span><span class="s1">&#39;BASIC_FORMAT&#39;</span><span class="p">,</span> <span class="s1">&#39;BufferingFormatter&#39;</span><span class="p">,</span> <span class="s1">&#39;CRITICAL&#39;</span><span class="p">,</span> <span class="s1">&#39;DEBUG&#39;</span><span class="p">,</span> <span class="s1">&#39;ERROR&#39;</span><span class="p">,</span>
  89. <span class="s1">&#39;FATAL&#39;</span><span class="p">,</span> <span class="s1">&#39;FileHandler&#39;</span><span class="p">,</span> <span class="s1">&#39;Filter&#39;</span><span class="p">,</span> <span class="s1">&#39;Formatter&#39;</span><span class="p">,</span> <span class="s1">&#39;Handler&#39;</span><span class="p">,</span> <span class="s1">&#39;INFO&#39;</span><span class="p">,</span>
  90. <span class="s1">&#39;LogRecord&#39;</span><span class="p">,</span> <span class="s1">&#39;Logger&#39;</span><span class="p">,</span> <span class="s1">&#39;LoggerAdapter&#39;</span><span class="p">,</span> <span class="s1">&#39;NOTSET&#39;</span><span class="p">,</span> <span class="s1">&#39;NullHandler&#39;</span><span class="p">,</span>
  91. <span class="s1">&#39;StreamHandler&#39;</span><span class="p">,</span> <span class="s1">&#39;WARN&#39;</span><span class="p">,</span> <span class="s1">&#39;WARNING&#39;</span><span class="p">,</span> <span class="s1">&#39;addLevelName&#39;</span><span class="p">,</span> <span class="s1">&#39;basicConfig&#39;</span><span class="p">,</span>
  92. <span class="s1">&#39;captureWarnings&#39;</span><span class="p">,</span> <span class="s1">&#39;critical&#39;</span><span class="p">,</span> <span class="s1">&#39;debug&#39;</span><span class="p">,</span> <span class="s1">&#39;disable&#39;</span><span class="p">,</span> <span class="s1">&#39;error&#39;</span><span class="p">,</span>
  93. <span class="s1">&#39;exception&#39;</span><span class="p">,</span> <span class="s1">&#39;fatal&#39;</span><span class="p">,</span> <span class="s1">&#39;getLevelName&#39;</span><span class="p">,</span> <span class="s1">&#39;getLogger&#39;</span><span class="p">,</span> <span class="s1">&#39;getLoggerClass&#39;</span><span class="p">,</span>
  94. <span class="s1">&#39;info&#39;</span><span class="p">,</span> <span class="s1">&#39;log&#39;</span><span class="p">,</span> <span class="s1">&#39;makeLogRecord&#39;</span><span class="p">,</span> <span class="s1">&#39;setLoggerClass&#39;</span><span class="p">,</span> <span class="s1">&#39;shutdown&#39;</span><span class="p">,</span>
  95. <span class="s1">&#39;warn&#39;</span><span class="p">,</span> <span class="s1">&#39;warning&#39;</span><span class="p">,</span> <span class="s1">&#39;getLogRecordFactory&#39;</span><span class="p">,</span> <span class="s1">&#39;setLogRecordFactory&#39;</span><span class="p">,</span>
  96. <span class="s1">&#39;lastResort&#39;</span><span class="p">,</span> <span class="s1">&#39;raiseExceptions&#39;</span><span class="p">]</span>
  97. <span class="kn">import</span> <span class="nn">threading</span>
  98. <span class="n">__author__</span> <span class="o">=</span> <span class="s2">&quot;Vinay Sajip &lt;vinay_sajip@red-dove.com&gt;&quot;</span>
  99. <span class="n">__status__</span> <span class="o">=</span> <span class="s2">&quot;production&quot;</span>
  100. <span class="c1"># The following module attributes are no longer updated.</span>
  101. <span class="n">__version__</span> <span class="o">=</span> <span class="s2">&quot;0.5.1.2&quot;</span>
  102. <span class="n">__date__</span> <span class="o">=</span> <span class="s2">&quot;07 February 2010&quot;</span>
  103. <span class="c1">#---------------------------------------------------------------------------</span>
  104. <span class="c1"># Miscellaneous module data</span>
  105. <span class="c1">#---------------------------------------------------------------------------</span>
  106. <span class="c1">#</span>
  107. <span class="c1">#_startTime is used as the base when calculating the relative time of events</span>
  108. <span class="c1">#</span>
  109. <span class="n">_startTime</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
  110. <span class="c1">#</span>
  111. <span class="c1">#raiseExceptions is used to see if exceptions during handling should be</span>
  112. <span class="c1">#propagated</span>
  113. <span class="c1">#</span>
  114. <span class="n">raiseExceptions</span> <span class="o">=</span> <span class="kc">True</span>
  115. <span class="c1">#</span>
  116. <span class="c1"># If you don&#39;t want threading information in the log, set this to zero</span>
  117. <span class="c1">#</span>
  118. <span class="n">logThreads</span> <span class="o">=</span> <span class="kc">True</span>
  119. <span class="c1">#</span>
  120. <span class="c1"># If you don&#39;t want multiprocessing information in the log, set this to zero</span>
  121. <span class="c1">#</span>
  122. <span class="n">logMultiprocessing</span> <span class="o">=</span> <span class="kc">True</span>
  123. <span class="c1">#</span>
  124. <span class="c1"># If you don&#39;t want process information in the log, set this to zero</span>
  125. <span class="c1">#</span>
  126. <span class="n">logProcesses</span> <span class="o">=</span> <span class="kc">True</span>
  127. <span class="c1">#---------------------------------------------------------------------------</span>
  128. <span class="c1"># Level related stuff</span>
  129. <span class="c1">#---------------------------------------------------------------------------</span>
  130. <span class="c1">#</span>
  131. <span class="c1"># Default levels and level names, these can be replaced with any positive set</span>
  132. <span class="c1"># of values having corresponding names. There is a pseudo-level, NOTSET, which</span>
  133. <span class="c1"># is only really there as a lower limit for user-defined levels. Handlers and</span>
  134. <span class="c1"># loggers are initialized with NOTSET so that they will log all messages, even</span>
  135. <span class="c1"># at user-defined levels.</span>
  136. <span class="c1">#</span>
  137. <span class="n">CRITICAL</span> <span class="o">=</span> <span class="mi">50</span>
  138. <span class="n">FATAL</span> <span class="o">=</span> <span class="n">CRITICAL</span>
  139. <span class="n">ERROR</span> <span class="o">=</span> <span class="mi">40</span>
  140. <span class="n">WARNING</span> <span class="o">=</span> <span class="mi">30</span>
  141. <span class="n">WARN</span> <span class="o">=</span> <span class="n">WARNING</span>
  142. <span class="n">INFO</span> <span class="o">=</span> <span class="mi">20</span>
  143. <span class="n">DEBUG</span> <span class="o">=</span> <span class="mi">10</span>
  144. <span class="n">NOTSET</span> <span class="o">=</span> <span class="mi">0</span>
  145. <span class="n">_levelToName</span> <span class="o">=</span> <span class="p">{</span>
  146. <span class="n">CRITICAL</span><span class="p">:</span> <span class="s1">&#39;CRITICAL&#39;</span><span class="p">,</span>
  147. <span class="n">ERROR</span><span class="p">:</span> <span class="s1">&#39;ERROR&#39;</span><span class="p">,</span>
  148. <span class="n">WARNING</span><span class="p">:</span> <span class="s1">&#39;WARNING&#39;</span><span class="p">,</span>
  149. <span class="n">INFO</span><span class="p">:</span> <span class="s1">&#39;INFO&#39;</span><span class="p">,</span>
  150. <span class="n">DEBUG</span><span class="p">:</span> <span class="s1">&#39;DEBUG&#39;</span><span class="p">,</span>
  151. <span class="n">NOTSET</span><span class="p">:</span> <span class="s1">&#39;NOTSET&#39;</span><span class="p">,</span>
  152. <span class="p">}</span>
  153. <span class="n">_nameToLevel</span> <span class="o">=</span> <span class="p">{</span>
  154. <span class="s1">&#39;CRITICAL&#39;</span><span class="p">:</span> <span class="n">CRITICAL</span><span class="p">,</span>
  155. <span class="s1">&#39;FATAL&#39;</span><span class="p">:</span> <span class="n">FATAL</span><span class="p">,</span>
  156. <span class="s1">&#39;ERROR&#39;</span><span class="p">:</span> <span class="n">ERROR</span><span class="p">,</span>
  157. <span class="s1">&#39;WARN&#39;</span><span class="p">:</span> <span class="n">WARNING</span><span class="p">,</span>
  158. <span class="s1">&#39;WARNING&#39;</span><span class="p">:</span> <span class="n">WARNING</span><span class="p">,</span>
  159. <span class="s1">&#39;INFO&#39;</span><span class="p">:</span> <span class="n">INFO</span><span class="p">,</span>
  160. <span class="s1">&#39;DEBUG&#39;</span><span class="p">:</span> <span class="n">DEBUG</span><span class="p">,</span>
  161. <span class="s1">&#39;NOTSET&#39;</span><span class="p">:</span> <span class="n">NOTSET</span><span class="p">,</span>
  162. <span class="p">}</span>
  163. <span class="k">def</span> <span class="nf">getLevelName</span><span class="p">(</span><span class="n">level</span><span class="p">):</span>
  164. <span class="sd">&quot;&quot;&quot;</span>
  165. <span class="sd"> Return the textual or numeric representation of logging level &#39;level&#39;.</span>
  166. <span class="sd"> If the level is one of the predefined levels (CRITICAL, ERROR, WARNING,</span>
  167. <span class="sd"> INFO, DEBUG) then you get the corresponding string. If you have</span>
  168. <span class="sd"> associated levels with names using addLevelName then the name you have</span>
  169. <span class="sd"> associated with &#39;level&#39; is returned.</span>
  170. <span class="sd"> If a numeric value corresponding to one of the defined levels is passed</span>
  171. <span class="sd"> in, the corresponding string representation is returned.</span>
  172. <span class="sd"> If a string representation of the level is passed in, the corresponding</span>
  173. <span class="sd"> numeric value is returned.</span>
  174. <span class="sd"> If no matching numeric or string value is passed in, the string</span>
  175. <span class="sd"> &#39;Level %s&#39; % level is returned.</span>
  176. <span class="sd"> &quot;&quot;&quot;</span>
  177. <span class="c1"># See Issues #22386, #27937 and #29220 for why it&#39;s this way</span>
  178. <span class="n">result</span> <span class="o">=</span> <span class="n">_levelToName</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">level</span><span class="p">)</span>
  179. <span class="k">if</span> <span class="n">result</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  180. <span class="k">return</span> <span class="n">result</span>
  181. <span class="n">result</span> <span class="o">=</span> <span class="n">_nameToLevel</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">level</span><span class="p">)</span>
  182. <span class="k">if</span> <span class="n">result</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  183. <span class="k">return</span> <span class="n">result</span>
  184. <span class="k">return</span> <span class="s2">&quot;Level </span><span class="si">%s</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="n">level</span>
  185. <span class="k">def</span> <span class="nf">addLevelName</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="n">levelName</span><span class="p">):</span>
  186. <span class="sd">&quot;&quot;&quot;</span>
  187. <span class="sd"> Associate &#39;levelName&#39; with &#39;level&#39;.</span>
  188. <span class="sd"> This is used when converting levels to text during message formatting.</span>
  189. <span class="sd"> &quot;&quot;&quot;</span>
  190. <span class="n">_acquireLock</span><span class="p">()</span>
  191. <span class="k">try</span><span class="p">:</span> <span class="c1">#unlikely to cause an exception, but you never know...</span>
  192. <span class="n">_levelToName</span><span class="p">[</span><span class="n">level</span><span class="p">]</span> <span class="o">=</span> <span class="n">levelName</span>
  193. <span class="n">_nameToLevel</span><span class="p">[</span><span class="n">levelName</span><span class="p">]</span> <span class="o">=</span> <span class="n">level</span>
  194. <span class="k">finally</span><span class="p">:</span>
  195. <span class="n">_releaseLock</span><span class="p">()</span>
  196. <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">sys</span><span class="p">,</span> <span class="s1">&#39;_getframe&#39;</span><span class="p">):</span>
  197. <span class="n">currentframe</span> <span class="o">=</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">sys</span><span class="o">.</span><span class="n">_getframe</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span>
  198. <span class="k">else</span><span class="p">:</span> <span class="c1">#pragma: no cover</span>
  199. <span class="k">def</span> <span class="nf">currentframe</span><span class="p">():</span>
  200. <span class="sd">&quot;&quot;&quot;Return the frame object for the caller&#39;s stack frame.&quot;&quot;&quot;</span>
  201. <span class="k">try</span><span class="p">:</span>
  202. <span class="k">raise</span> <span class="ne">Exception</span>
  203. <span class="k">except</span> <span class="ne">Exception</span><span class="p">:</span>
  204. <span class="k">return</span> <span class="n">sys</span><span class="o">.</span><span class="n">exc_info</span><span class="p">()[</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">tb_frame</span><span class="o">.</span><span class="n">f_back</span>
  205. <span class="c1">#</span>
  206. <span class="c1"># _srcfile is used when walking the stack to check when we&#39;ve got the first</span>
  207. <span class="c1"># caller stack frame, by skipping frames whose filename is that of this</span>
  208. <span class="c1"># module&#39;s source. It therefore should contain the filename of this module&#39;s</span>
  209. <span class="c1"># source file.</span>
  210. <span class="c1">#</span>
  211. <span class="c1"># Ordinarily we would use __file__ for this, but frozen modules don&#39;t always</span>
  212. <span class="c1"># have __file__ set, for some reason (see Issue #21736). Thus, we get the</span>
  213. <span class="c1"># filename from a handy code object from a function defined in this module.</span>
  214. <span class="c1"># (There&#39;s no particular reason for picking addLevelName.)</span>
  215. <span class="c1">#</span>
  216. <span class="n">_srcfile</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">normcase</span><span class="p">(</span><span class="n">addLevelName</span><span class="o">.</span><span class="vm">__code__</span><span class="o">.</span><span class="n">co_filename</span><span class="p">)</span>
  217. <span class="c1"># _srcfile is only used in conjunction with sys._getframe().</span>
  218. <span class="c1"># To provide compatibility with older versions of Python, set _srcfile</span>
  219. <span class="c1"># to None if _getframe() is not available; this value will prevent</span>
  220. <span class="c1"># findCaller() from being called. You can also do this if you want to avoid</span>
  221. <span class="c1"># the overhead of fetching caller information, even when _getframe() is</span>
  222. <span class="c1"># available.</span>
  223. <span class="c1">#if not hasattr(sys, &#39;_getframe&#39;):</span>
  224. <span class="c1"># _srcfile = None</span>
  225. <span class="k">def</span> <span class="nf">_checkLevel</span><span class="p">(</span><span class="n">level</span><span class="p">):</span>
  226. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
  227. <span class="n">rv</span> <span class="o">=</span> <span class="n">level</span>
  228. <span class="k">elif</span> <span class="nb">str</span><span class="p">(</span><span class="n">level</span><span class="p">)</span> <span class="o">==</span> <span class="n">level</span><span class="p">:</span>
  229. <span class="k">if</span> <span class="n">level</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">_nameToLevel</span><span class="p">:</span>
  230. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Unknown level: </span><span class="si">%r</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="n">level</span><span class="p">)</span>
  231. <span class="n">rv</span> <span class="o">=</span> <span class="n">_nameToLevel</span><span class="p">[</span><span class="n">level</span><span class="p">]</span>
  232. <span class="k">else</span><span class="p">:</span>
  233. <span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">&quot;Level not an integer or a valid string: </span><span class="si">%r</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="n">level</span><span class="p">)</span>
  234. <span class="k">return</span> <span class="n">rv</span>
  235. <span class="c1">#---------------------------------------------------------------------------</span>
  236. <span class="c1"># Thread-related stuff</span>
  237. <span class="c1">#---------------------------------------------------------------------------</span>
  238. <span class="c1">#</span>
  239. <span class="c1">#_lock is used to serialize access to shared data structures in this module.</span>
  240. <span class="c1">#This needs to be an RLock because fileConfig() creates and configures</span>
  241. <span class="c1">#Handlers, and so might arbitrary user threads. Since Handler code updates the</span>
  242. <span class="c1">#shared dictionary _handlers, it needs to acquire the lock. But if configuring,</span>
  243. <span class="c1">#the lock would already have been acquired - so we need an RLock.</span>
  244. <span class="c1">#The same argument applies to Loggers and Manager.loggerDict.</span>
  245. <span class="c1">#</span>
  246. <span class="n">_lock</span> <span class="o">=</span> <span class="n">threading</span><span class="o">.</span><span class="n">RLock</span><span class="p">()</span>
  247. <span class="k">def</span> <span class="nf">_acquireLock</span><span class="p">():</span>
  248. <span class="sd">&quot;&quot;&quot;</span>
  249. <span class="sd"> Acquire the module-level lock for serializing access to shared data.</span>
  250. <span class="sd"> This should be released with _releaseLock().</span>
  251. <span class="sd"> &quot;&quot;&quot;</span>
  252. <span class="k">if</span> <span class="n">_lock</span><span class="p">:</span>
  253. <span class="n">_lock</span><span class="o">.</span><span class="n">acquire</span><span class="p">()</span>
  254. <span class="k">def</span> <span class="nf">_releaseLock</span><span class="p">():</span>
  255. <span class="sd">&quot;&quot;&quot;</span>
  256. <span class="sd"> Release the module-level lock acquired by calling _acquireLock().</span>
  257. <span class="sd"> &quot;&quot;&quot;</span>
  258. <span class="k">if</span> <span class="n">_lock</span><span class="p">:</span>
  259. <span class="n">_lock</span><span class="o">.</span><span class="n">release</span><span class="p">()</span>
  260. <span class="c1"># Prevent a held logging lock from blocking a child from logging.</span>
  261. <span class="k">if</span> <span class="ow">not</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">os</span><span class="p">,</span> <span class="s1">&#39;register_at_fork&#39;</span><span class="p">):</span> <span class="c1"># Windows and friends.</span>
  262. <span class="k">def</span> <span class="nf">_register_at_fork_reinit_lock</span><span class="p">(</span><span class="n">instance</span><span class="p">):</span>
  263. <span class="k">pass</span> <span class="c1"># no-op when os.register_at_fork does not exist.</span>
  264. <span class="k">else</span><span class="p">:</span>
  265. <span class="c1"># A collection of instances with a _at_fork_reinit method (logging.Handler)</span>
  266. <span class="c1"># to be called in the child after forking. The weakref avoids us keeping</span>
  267. <span class="c1"># discarded Handler instances alive.</span>
  268. <span class="n">_at_fork_reinit_lock_weakset</span> <span class="o">=</span> <span class="n">weakref</span><span class="o">.</span><span class="n">WeakSet</span><span class="p">()</span>
  269. <span class="k">def</span> <span class="nf">_register_at_fork_reinit_lock</span><span class="p">(</span><span class="n">instance</span><span class="p">):</span>
  270. <span class="n">_acquireLock</span><span class="p">()</span>
  271. <span class="k">try</span><span class="p">:</span>
  272. <span class="n">_at_fork_reinit_lock_weakset</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">instance</span><span class="p">)</span>
  273. <span class="k">finally</span><span class="p">:</span>
  274. <span class="n">_releaseLock</span><span class="p">()</span>
  275. <span class="k">def</span> <span class="nf">_after_at_fork_child_reinit_locks</span><span class="p">():</span>
  276. <span class="k">for</span> <span class="n">handler</span> <span class="ow">in</span> <span class="n">_at_fork_reinit_lock_weakset</span><span class="p">:</span>
  277. <span class="n">handler</span><span class="o">.</span><span class="n">_at_fork_reinit</span><span class="p">()</span>
  278. <span class="c1"># _acquireLock() was called in the parent before forking.</span>
  279. <span class="c1"># The lock is reinitialized to unlocked state.</span>
  280. <span class="n">_lock</span><span class="o">.</span><span class="n">_at_fork_reinit</span><span class="p">()</span>
  281. <span class="n">os</span><span class="o">.</span><span class="n">register_at_fork</span><span class="p">(</span><span class="n">before</span><span class="o">=</span><span class="n">_acquireLock</span><span class="p">,</span>
  282. <span class="n">after_in_child</span><span class="o">=</span><span class="n">_after_at_fork_child_reinit_locks</span><span class="p">,</span>
  283. <span class="n">after_in_parent</span><span class="o">=</span><span class="n">_releaseLock</span><span class="p">)</span>
  284. <span class="c1">#---------------------------------------------------------------------------</span>
  285. <span class="c1"># The logging record</span>
  286. <span class="c1">#---------------------------------------------------------------------------</span>
  287. <span class="k">class</span> <span class="nc">LogRecord</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
  288. <span class="sd">&quot;&quot;&quot;</span>
  289. <span class="sd"> A LogRecord instance represents an event being logged.</span>
  290. <span class="sd"> LogRecord instances are created every time something is logged. They</span>
  291. <span class="sd"> contain all the information pertinent to the event being logged. The</span>
  292. <span class="sd"> main information passed in is in msg and args, which are combined</span>
  293. <span class="sd"> using str(msg) % args to create the message field of the record. The</span>
  294. <span class="sd"> record also includes information such as when the record was created,</span>
  295. <span class="sd"> the source line where the logging call was made, and any exception</span>
  296. <span class="sd"> information to be logged.</span>
  297. <span class="sd"> &quot;&quot;&quot;</span>
  298. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">level</span><span class="p">,</span> <span class="n">pathname</span><span class="p">,</span> <span class="n">lineno</span><span class="p">,</span>
  299. <span class="n">msg</span><span class="p">,</span> <span class="n">args</span><span class="p">,</span> <span class="n">exc_info</span><span class="p">,</span> <span class="n">func</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">sinfo</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  300. <span class="sd">&quot;&quot;&quot;</span>
  301. <span class="sd"> Initialize a logging record with interesting information.</span>
  302. <span class="sd"> &quot;&quot;&quot;</span>
  303. <span class="n">ct</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
  304. <span class="bp">self</span><span class="o">.</span><span class="n">name</span> <span class="o">=</span> <span class="n">name</span>
  305. <span class="bp">self</span><span class="o">.</span><span class="n">msg</span> <span class="o">=</span> <span class="n">msg</span>
  306. <span class="c1">#</span>
  307. <span class="c1"># The following statement allows passing of a dictionary as a sole</span>
  308. <span class="c1"># argument, so that you can do something like</span>
  309. <span class="c1"># logging.debug(&quot;a %(a)d b %(b)s&quot;, {&#39;a&#39;:1, &#39;b&#39;:2})</span>
  310. <span class="c1"># Suggested by Stefan Behnel.</span>
  311. <span class="c1"># Note that without the test for args[0], we get a problem because</span>
  312. <span class="c1"># during formatting, we test to see if the arg is present using</span>
  313. <span class="c1"># &#39;if self.args:&#39;. If the event being logged is e.g. &#39;Value is %d&#39;</span>
  314. <span class="c1"># and if the passed arg fails &#39;if self.args:&#39; then no formatting</span>
  315. <span class="c1"># is done. For example, logger.warning(&#39;Value is %d&#39;, 0) would log</span>
  316. <span class="c1"># &#39;Value is %d&#39; instead of &#39;Value is 0&#39;.</span>
  317. <span class="c1"># For the use case of passing a dictionary, this should not be a</span>
  318. <span class="c1"># problem.</span>
  319. <span class="c1"># Issue #21172: a request was made to relax the isinstance check</span>
  320. <span class="c1"># to hasattr(args[0], &#39;__getitem__&#39;). However, the docs on string</span>
  321. <span class="c1"># formatting still seem to suggest a mapping object is required.</span>
  322. <span class="c1"># Thus, while not removing the isinstance check, it does now look</span>
  323. <span class="c1"># for collections.abc.Mapping rather than, as before, dict.</span>
  324. <span class="k">if</span> <span class="p">(</span><span class="n">args</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">args</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">args</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">collections</span><span class="o">.</span><span class="n">abc</span><span class="o">.</span><span class="n">Mapping</span><span class="p">)</span>
  325. <span class="ow">and</span> <span class="n">args</span><span class="p">[</span><span class="mi">0</span><span class="p">]):</span>
  326. <span class="n">args</span> <span class="o">=</span> <span class="n">args</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
  327. <span class="bp">self</span><span class="o">.</span><span class="n">args</span> <span class="o">=</span> <span class="n">args</span>
  328. <span class="bp">self</span><span class="o">.</span><span class="n">levelname</span> <span class="o">=</span> <span class="n">getLevelName</span><span class="p">(</span><span class="n">level</span><span class="p">)</span>
  329. <span class="bp">self</span><span class="o">.</span><span class="n">levelno</span> <span class="o">=</span> <span class="n">level</span>
  330. <span class="bp">self</span><span class="o">.</span><span class="n">pathname</span> <span class="o">=</span> <span class="n">pathname</span>
  331. <span class="k">try</span><span class="p">:</span>
  332. <span class="bp">self</span><span class="o">.</span><span class="n">filename</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">basename</span><span class="p">(</span><span class="n">pathname</span><span class="p">)</span>
  333. <span class="bp">self</span><span class="o">.</span><span class="n">module</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">splitext</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">filename</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
  334. <span class="k">except</span> <span class="p">(</span><span class="ne">TypeError</span><span class="p">,</span> <span class="ne">ValueError</span><span class="p">,</span> <span class="ne">AttributeError</span><span class="p">):</span>
  335. <span class="bp">self</span><span class="o">.</span><span class="n">filename</span> <span class="o">=</span> <span class="n">pathname</span>
  336. <span class="bp">self</span><span class="o">.</span><span class="n">module</span> <span class="o">=</span> <span class="s2">&quot;Unknown module&quot;</span>
  337. <span class="bp">self</span><span class="o">.</span><span class="n">exc_info</span> <span class="o">=</span> <span class="n">exc_info</span>
  338. <span class="bp">self</span><span class="o">.</span><span class="n">exc_text</span> <span class="o">=</span> <span class="kc">None</span> <span class="c1"># used to cache the traceback text</span>
  339. <span class="bp">self</span><span class="o">.</span><span class="n">stack_info</span> <span class="o">=</span> <span class="n">sinfo</span>
  340. <span class="bp">self</span><span class="o">.</span><span class="n">lineno</span> <span class="o">=</span> <span class="n">lineno</span>
  341. <span class="bp">self</span><span class="o">.</span><span class="n">funcName</span> <span class="o">=</span> <span class="n">func</span>
  342. <span class="bp">self</span><span class="o">.</span><span class="n">created</span> <span class="o">=</span> <span class="n">ct</span>
  343. <span class="bp">self</span><span class="o">.</span><span class="n">msecs</span> <span class="o">=</span> <span class="p">(</span><span class="n">ct</span> <span class="o">-</span> <span class="nb">int</span><span class="p">(</span><span class="n">ct</span><span class="p">))</span> <span class="o">*</span> <span class="mi">1000</span>
  344. <span class="bp">self</span><span class="o">.</span><span class="n">relativeCreated</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">created</span> <span class="o">-</span> <span class="n">_startTime</span><span class="p">)</span> <span class="o">*</span> <span class="mi">1000</span>
  345. <span class="k">if</span> <span class="n">logThreads</span><span class="p">:</span>
  346. <span class="bp">self</span><span class="o">.</span><span class="n">thread</span> <span class="o">=</span> <span class="n">threading</span><span class="o">.</span><span class="n">get_ident</span><span class="p">()</span>
  347. <span class="bp">self</span><span class="o">.</span><span class="n">threadName</span> <span class="o">=</span> <span class="n">threading</span><span class="o">.</span><span class="n">current_thread</span><span class="p">()</span><span class="o">.</span><span class="n">name</span>
  348. <span class="k">else</span><span class="p">:</span> <span class="c1"># pragma: no cover</span>
  349. <span class="bp">self</span><span class="o">.</span><span class="n">thread</span> <span class="o">=</span> <span class="kc">None</span>
  350. <span class="bp">self</span><span class="o">.</span><span class="n">threadName</span> <span class="o">=</span> <span class="kc">None</span>
  351. <span class="k">if</span> <span class="ow">not</span> <span class="n">logMultiprocessing</span><span class="p">:</span> <span class="c1"># pragma: no cover</span>
  352. <span class="bp">self</span><span class="o">.</span><span class="n">processName</span> <span class="o">=</span> <span class="kc">None</span>
  353. <span class="k">else</span><span class="p">:</span>
  354. <span class="bp">self</span><span class="o">.</span><span class="n">processName</span> <span class="o">=</span> <span class="s1">&#39;MainProcess&#39;</span>
  355. <span class="n">mp</span> <span class="o">=</span> <span class="n">sys</span><span class="o">.</span><span class="n">modules</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;multiprocessing&#39;</span><span class="p">)</span>
  356. <span class="k">if</span> <span class="n">mp</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  357. <span class="c1"># Errors may occur if multiprocessing has not finished loading</span>
  358. <span class="c1"># yet - e.g. if a custom import hook causes third-party code</span>
  359. <span class="c1"># to run when multiprocessing calls import. See issue 8200</span>
  360. <span class="c1"># for an example</span>
  361. <span class="k">try</span><span class="p">:</span>
  362. <span class="bp">self</span><span class="o">.</span><span class="n">processName</span> <span class="o">=</span> <span class="n">mp</span><span class="o">.</span><span class="n">current_process</span><span class="p">()</span><span class="o">.</span><span class="n">name</span>
  363. <span class="k">except</span> <span class="ne">Exception</span><span class="p">:</span> <span class="c1">#pragma: no cover</span>
  364. <span class="k">pass</span>
  365. <span class="k">if</span> <span class="n">logProcesses</span> <span class="ow">and</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">os</span><span class="p">,</span> <span class="s1">&#39;getpid&#39;</span><span class="p">):</span>
  366. <span class="bp">self</span><span class="o">.</span><span class="n">process</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">getpid</span><span class="p">()</span>
  367. <span class="k">else</span><span class="p">:</span>
  368. <span class="bp">self</span><span class="o">.</span><span class="n">process</span> <span class="o">=</span> <span class="kc">None</span>
  369. <span class="k">def</span> <span class="fm">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  370. <span class="k">return</span> <span class="s1">&#39;&lt;LogRecord: </span><span class="si">%s</span><span class="s1">, </span><span class="si">%s</span><span class="s1">, </span><span class="si">%s</span><span class="s1">, </span><span class="si">%s</span><span class="s1">, &quot;</span><span class="si">%s</span><span class="s1">&quot;&gt;&#39;</span><span class="o">%</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">levelno</span><span class="p">,</span>
  371. <span class="bp">self</span><span class="o">.</span><span class="n">pathname</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">lineno</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">msg</span><span class="p">)</span>
  372. <span class="k">def</span> <span class="nf">getMessage</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  373. <span class="sd">&quot;&quot;&quot;</span>
  374. <span class="sd"> Return the message for this LogRecord.</span>
  375. <span class="sd"> Return the message for this LogRecord after merging any user-supplied</span>
  376. <span class="sd"> arguments with the message.</span>
  377. <span class="sd"> &quot;&quot;&quot;</span>
  378. <span class="n">msg</span> <span class="o">=</span> <span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">msg</span><span class="p">)</span>
  379. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">args</span><span class="p">:</span>
  380. <span class="n">msg</span> <span class="o">=</span> <span class="n">msg</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">args</span>
  381. <span class="k">return</span> <span class="n">msg</span>
  382. <span class="c1">#</span>
  383. <span class="c1"># Determine which class to use when instantiating log records.</span>
  384. <span class="c1">#</span>
  385. <span class="n">_logRecordFactory</span> <span class="o">=</span> <span class="n">LogRecord</span>
  386. <span class="k">def</span> <span class="nf">setLogRecordFactory</span><span class="p">(</span><span class="n">factory</span><span class="p">):</span>
  387. <span class="sd">&quot;&quot;&quot;</span>
  388. <span class="sd"> Set the factory to be used when instantiating a log record.</span>
  389. <span class="sd"> :param factory: A callable which will be called to instantiate</span>
  390. <span class="sd"> a log record.</span>
  391. <span class="sd"> &quot;&quot;&quot;</span>
  392. <span class="k">global</span> <span class="n">_logRecordFactory</span>
  393. <span class="n">_logRecordFactory</span> <span class="o">=</span> <span class="n">factory</span>
  394. <span class="k">def</span> <span class="nf">getLogRecordFactory</span><span class="p">():</span>
  395. <span class="sd">&quot;&quot;&quot;</span>
  396. <span class="sd"> Return the factory to be used when instantiating a log record.</span>
  397. <span class="sd"> &quot;&quot;&quot;</span>
  398. <span class="k">return</span> <span class="n">_logRecordFactory</span>
  399. <span class="k">def</span> <span class="nf">makeLogRecord</span><span class="p">(</span><span class="nb">dict</span><span class="p">):</span>
  400. <span class="sd">&quot;&quot;&quot;</span>
  401. <span class="sd"> Make a LogRecord whose attributes are defined by the specified dictionary,</span>
  402. <span class="sd"> This function is useful for converting a logging event received over</span>
  403. <span class="sd"> a socket connection (which is sent as a dictionary) into a LogRecord</span>
  404. <span class="sd"> instance.</span>
  405. <span class="sd"> &quot;&quot;&quot;</span>
  406. <span class="n">rv</span> <span class="o">=</span> <span class="n">_logRecordFactory</span><span class="p">(</span><span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="s2">&quot;&quot;</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="s2">&quot;&quot;</span><span class="p">,</span> <span class="p">(),</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
  407. <span class="n">rv</span><span class="o">.</span><span class="vm">__dict__</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="nb">dict</span><span class="p">)</span>
  408. <span class="k">return</span> <span class="n">rv</span>
  409. <span class="c1">#---------------------------------------------------------------------------</span>
  410. <span class="c1"># Formatter classes and functions</span>
  411. <span class="c1">#---------------------------------------------------------------------------</span>
  412. <span class="n">_str_formatter</span> <span class="o">=</span> <span class="n">StrFormatter</span><span class="p">()</span>
  413. <span class="k">del</span> <span class="n">StrFormatter</span>
  414. <span class="k">class</span> <span class="nc">PercentStyle</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
  415. <span class="n">default_format</span> <span class="o">=</span> <span class="s1">&#39;</span><span class="si">%(message)s</span><span class="s1">&#39;</span>
  416. <span class="n">asctime_format</span> <span class="o">=</span> <span class="s1">&#39;</span><span class="si">%(asctime)s</span><span class="s1">&#39;</span>
  417. <span class="n">asctime_search</span> <span class="o">=</span> <span class="s1">&#39;%(asctime)&#39;</span>
  418. <span class="n">validation_pattern</span> <span class="o">=</span> <span class="n">re</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="sa">r</span><span class="s1">&#39;%\(\w+\)[#0+ -]*(\*|\d+)?(\.(\*|\d+))?[diouxefgcrsa%]&#39;</span><span class="p">,</span> <span class="n">re</span><span class="o">.</span><span class="n">I</span><span class="p">)</span>
  419. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">fmt</span><span class="p">):</span>
  420. <span class="bp">self</span><span class="o">.</span><span class="n">_fmt</span> <span class="o">=</span> <span class="n">fmt</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">default_format</span>
  421. <span class="k">def</span> <span class="nf">usesTime</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  422. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_fmt</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">asctime_search</span><span class="p">)</span> <span class="o">&gt;=</span> <span class="mi">0</span>
  423. <span class="k">def</span> <span class="nf">validate</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  424. <span class="sd">&quot;&quot;&quot;Validate the input format, ensure it matches the correct style&quot;&quot;&quot;</span>
  425. <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">validation_pattern</span><span class="o">.</span><span class="n">search</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_fmt</span><span class="p">):</span>
  426. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Invalid format &#39;</span><span class="si">%s</span><span class="s2">&#39; for &#39;</span><span class="si">%s</span><span class="s2">&#39; style&quot;</span> <span class="o">%</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_fmt</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">default_format</span><span class="p">[</span><span class="mi">0</span><span class="p">]))</span>
  427. <span class="k">def</span> <span class="nf">_format</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">record</span><span class="p">):</span>
  428. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_fmt</span> <span class="o">%</span> <span class="n">record</span><span class="o">.</span><span class="vm">__dict__</span>
  429. <span class="k">def</span> <span class="nf">format</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">record</span><span class="p">):</span>
  430. <span class="k">try</span><span class="p">:</span>
  431. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_format</span><span class="p">(</span><span class="n">record</span><span class="p">)</span>
  432. <span class="k">except</span> <span class="ne">KeyError</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
  433. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">&#39;Formatting field not found in record: </span><span class="si">%s</span><span class="s1">&#39;</span> <span class="o">%</span> <span class="n">e</span><span class="p">)</span>
  434. <span class="k">class</span> <span class="nc">StrFormatStyle</span><span class="p">(</span><span class="n">PercentStyle</span><span class="p">):</span>
  435. <span class="n">default_format</span> <span class="o">=</span> <span class="s1">&#39;</span><span class="si">{message}</span><span class="s1">&#39;</span>
  436. <span class="n">asctime_format</span> <span class="o">=</span> <span class="s1">&#39;</span><span class="si">{asctime}</span><span class="s1">&#39;</span>
  437. <span class="n">asctime_search</span> <span class="o">=</span> <span class="s1">&#39;{asctime&#39;</span>
  438. <span class="n">fmt_spec</span> <span class="o">=</span> <span class="n">re</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="sa">r</span><span class="s1">&#39;^(.?[&lt;&gt;=^])?[+ -]?#?0?(\d+|{\w+})?[,_]?(\.(\d+|{\w+}))?[bcdefgnosx%]?$&#39;</span><span class="p">,</span> <span class="n">re</span><span class="o">.</span><span class="n">I</span><span class="p">)</span>
  439. <span class="n">field_spec</span> <span class="o">=</span> <span class="n">re</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="sa">r</span><span class="s1">&#39;^(\d+|\w+)(\.\w+|\[[^]]+\])*$&#39;</span><span class="p">)</span>
  440. <span class="k">def</span> <span class="nf">_format</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">record</span><span class="p">):</span>
  441. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_fmt</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="o">**</span><span class="n">record</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">)</span>
  442. <span class="k">def</span> <span class="nf">validate</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  443. <span class="sd">&quot;&quot;&quot;Validate the input format, ensure it is the correct string formatting style&quot;&quot;&quot;</span>
  444. <span class="n">fields</span> <span class="o">=</span> <span class="nb">set</span><span class="p">()</span>
  445. <span class="k">try</span><span class="p">:</span>
  446. <span class="k">for</span> <span class="n">_</span><span class="p">,</span> <span class="n">fieldname</span><span class="p">,</span> <span class="n">spec</span><span class="p">,</span> <span class="n">conversion</span> <span class="ow">in</span> <span class="n">_str_formatter</span><span class="o">.</span><span class="n">parse</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_fmt</span><span class="p">):</span>
  447. <span class="k">if</span> <span class="n">fieldname</span><span class="p">:</span>
  448. <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">field_spec</span><span class="o">.</span><span class="n">match</span><span class="p">(</span><span class="n">fieldname</span><span class="p">):</span>
  449. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">&#39;invalid field name/expression: </span><span class="si">%r</span><span class="s1">&#39;</span> <span class="o">%</span> <span class="n">fieldname</span><span class="p">)</span>
  450. <span class="n">fields</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">fieldname</span><span class="p">)</span>
  451. <span class="k">if</span> <span class="n">conversion</span> <span class="ow">and</span> <span class="n">conversion</span> <span class="ow">not</span> <span class="ow">in</span> <span class="s1">&#39;rsa&#39;</span><span class="p">:</span>
  452. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">&#39;invalid conversion: </span><span class="si">%r</span><span class="s1">&#39;</span> <span class="o">%</span> <span class="n">conversion</span><span class="p">)</span>
  453. <span class="k">if</span> <span class="n">spec</span> <span class="ow">and</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">fmt_spec</span><span class="o">.</span><span class="n">match</span><span class="p">(</span><span class="n">spec</span><span class="p">):</span>
  454. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">&#39;bad specifier: </span><span class="si">%r</span><span class="s1">&#39;</span> <span class="o">%</span> <span class="n">spec</span><span class="p">)</span>
  455. <span class="k">except</span> <span class="ne">ValueError</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
  456. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">&#39;invalid format: </span><span class="si">%s</span><span class="s1">&#39;</span> <span class="o">%</span> <span class="n">e</span><span class="p">)</span>
  457. <span class="k">if</span> <span class="ow">not</span> <span class="n">fields</span><span class="p">:</span>
  458. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">&#39;invalid format: no fields&#39;</span><span class="p">)</span>
  459. <span class="k">class</span> <span class="nc">StringTemplateStyle</span><span class="p">(</span><span class="n">PercentStyle</span><span class="p">):</span>
  460. <span class="n">default_format</span> <span class="o">=</span> <span class="s1">&#39;$</span><span class="si">{message}</span><span class="s1">&#39;</span>
  461. <span class="n">asctime_format</span> <span class="o">=</span> <span class="s1">&#39;$</span><span class="si">{asctime}</span><span class="s1">&#39;</span>
  462. <span class="n">asctime_search</span> <span class="o">=</span> <span class="s1">&#39;$</span><span class="si">{asctime}</span><span class="s1">&#39;</span>
  463. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">fmt</span><span class="p">):</span>
  464. <span class="bp">self</span><span class="o">.</span><span class="n">_fmt</span> <span class="o">=</span> <span class="n">fmt</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">default_format</span>
  465. <span class="bp">self</span><span class="o">.</span><span class="n">_tpl</span> <span class="o">=</span> <span class="n">Template</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_fmt</span><span class="p">)</span>
  466. <span class="k">def</span> <span class="nf">usesTime</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  467. <span class="n">fmt</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_fmt</span>
  468. <span class="k">return</span> <span class="n">fmt</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="s1">&#39;$asctime&#39;</span><span class="p">)</span> <span class="o">&gt;=</span> <span class="mi">0</span> <span class="ow">or</span> <span class="n">fmt</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">asctime_format</span><span class="p">)</span> <span class="o">&gt;=</span> <span class="mi">0</span>
  469. <span class="k">def</span> <span class="nf">validate</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  470. <span class="n">pattern</span> <span class="o">=</span> <span class="n">Template</span><span class="o">.</span><span class="n">pattern</span>
  471. <span class="n">fields</span> <span class="o">=</span> <span class="nb">set</span><span class="p">()</span>
  472. <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="n">pattern</span><span class="o">.</span><span class="n">finditer</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_fmt</span><span class="p">):</span>
  473. <span class="n">d</span> <span class="o">=</span> <span class="n">m</span><span class="o">.</span><span class="n">groupdict</span><span class="p">()</span>
  474. <span class="k">if</span> <span class="n">d</span><span class="p">[</span><span class="s1">&#39;named&#39;</span><span class="p">]:</span>
  475. <span class="n">fields</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">d</span><span class="p">[</span><span class="s1">&#39;named&#39;</span><span class="p">])</span>
  476. <span class="k">elif</span> <span class="n">d</span><span class="p">[</span><span class="s1">&#39;braced&#39;</span><span class="p">]:</span>
  477. <span class="n">fields</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">d</span><span class="p">[</span><span class="s1">&#39;braced&#39;</span><span class="p">])</span>
  478. <span class="k">elif</span> <span class="n">m</span><span class="o">.</span><span class="n">group</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">==</span> <span class="s1">&#39;$&#39;</span><span class="p">:</span>
  479. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">&#39;invalid format: bare </span><span class="se">\&#39;</span><span class="s1">$</span><span class="se">\&#39;</span><span class="s1"> not allowed&#39;</span><span class="p">)</span>
  480. <span class="k">if</span> <span class="ow">not</span> <span class="n">fields</span><span class="p">:</span>
  481. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">&#39;invalid format: no fields&#39;</span><span class="p">)</span>
  482. <span class="k">def</span> <span class="nf">_format</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">record</span><span class="p">):</span>
  483. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tpl</span><span class="o">.</span><span class="n">substitute</span><span class="p">(</span><span class="o">**</span><span class="n">record</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">)</span>
  484. <span class="n">BASIC_FORMAT</span> <span class="o">=</span> <span class="s2">&quot;</span><span class="si">%(levelname)s</span><span class="s2">:</span><span class="si">%(name)s</span><span class="s2">:</span><span class="si">%(message)s</span><span class="s2">&quot;</span>
  485. <span class="n">_STYLES</span> <span class="o">=</span> <span class="p">{</span>
  486. <span class="s1">&#39;%&#39;</span><span class="p">:</span> <span class="p">(</span><span class="n">PercentStyle</span><span class="p">,</span> <span class="n">BASIC_FORMAT</span><span class="p">),</span>
  487. <span class="s1">&#39;{&#39;</span><span class="p">:</span> <span class="p">(</span><span class="n">StrFormatStyle</span><span class="p">,</span> <span class="s1">&#39;</span><span class="si">{levelname}</span><span class="s1">:</span><span class="si">{name}</span><span class="s1">:</span><span class="si">{message}</span><span class="s1">&#39;</span><span class="p">),</span>
  488. <span class="s1">&#39;$&#39;</span><span class="p">:</span> <span class="p">(</span><span class="n">StringTemplateStyle</span><span class="p">,</span> <span class="s1">&#39;$</span><span class="si">{levelname}</span><span class="s1">:$</span><span class="si">{name}</span><span class="s1">:$</span><span class="si">{message}</span><span class="s1">&#39;</span><span class="p">),</span>
  489. <span class="p">}</span>
  490. <span class="k">class</span> <span class="nc">Formatter</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
  491. <span class="sd">&quot;&quot;&quot;</span>
  492. <span class="sd"> Formatter instances are used to convert a LogRecord to text.</span>
  493. <span class="sd"> Formatters need to know how a LogRecord is constructed. They are</span>
  494. <span class="sd"> responsible for converting a LogRecord to (usually) a string which can</span>
  495. <span class="sd"> be interpreted by either a human or an external system. The base Formatter</span>
  496. <span class="sd"> allows a formatting string to be specified. If none is supplied, the</span>
  497. <span class="sd"> style-dependent default value, &quot;%(message)s&quot;, &quot;{message}&quot;, or</span>
  498. <span class="sd"> &quot;${message}&quot;, is used.</span>
  499. <span class="sd"> The Formatter can be initialized with a format string which makes use of</span>
  500. <span class="sd"> knowledge of the LogRecord attributes - e.g. the default value mentioned</span>
  501. <span class="sd"> above makes use of the fact that the user&#39;s message and arguments are pre-</span>
  502. <span class="sd"> formatted into a LogRecord&#39;s message attribute. Currently, the useful</span>
  503. <span class="sd"> attributes in a LogRecord are described by:</span>
  504. <span class="sd"> %(name)s Name of the logger (logging channel)</span>
  505. <span class="sd"> %(levelno)s Numeric logging level for the message (DEBUG, INFO,</span>
  506. <span class="sd"> WARNING, ERROR, CRITICAL)</span>
  507. <span class="sd"> %(levelname)s Text logging level for the message (&quot;DEBUG&quot;, &quot;INFO&quot;,</span>
  508. <span class="sd"> &quot;WARNING&quot;, &quot;ERROR&quot;, &quot;CRITICAL&quot;)</span>
  509. <span class="sd"> %(pathname)s Full pathname of the source file where the logging</span>
  510. <span class="sd"> call was issued (if available)</span>
  511. <span class="sd"> %(filename)s Filename portion of pathname</span>
  512. <span class="sd"> %(module)s Module (name portion of filename)</span>
  513. <span class="sd"> %(lineno)d Source line number where the logging call was issued</span>
  514. <span class="sd"> (if available)</span>
  515. <span class="sd"> %(funcName)s Function name</span>
  516. <span class="sd"> %(created)f Time when the LogRecord was created (time.time()</span>
  517. <span class="sd"> return value)</span>
  518. <span class="sd"> %(asctime)s Textual time when the LogRecord was created</span>
  519. <span class="sd"> %(msecs)d Millisecond portion of the creation time</span>
  520. <span class="sd"> %(relativeCreated)d Time in milliseconds when the LogRecord was created,</span>
  521. <span class="sd"> relative to the time the logging module was loaded</span>
  522. <span class="sd"> (typically at application startup time)</span>
  523. <span class="sd"> %(thread)d Thread ID (if available)</span>
  524. <span class="sd"> %(threadName)s Thread name (if available)</span>
  525. <span class="sd"> %(process)d Process ID (if available)</span>
  526. <span class="sd"> %(message)s The result of record.getMessage(), computed just as</span>
  527. <span class="sd"> the record is emitted</span>
  528. <span class="sd"> &quot;&quot;&quot;</span>
  529. <span class="n">converter</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">localtime</span>
  530. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">fmt</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">datefmt</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">style</span><span class="o">=</span><span class="s1">&#39;%&#39;</span><span class="p">,</span> <span class="n">validate</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
  531. <span class="sd">&quot;&quot;&quot;</span>
  532. <span class="sd"> Initialize the formatter with specified format strings.</span>
  533. <span class="sd"> Initialize the formatter either with the specified format string, or a</span>
  534. <span class="sd"> default as described above. Allow for specialized date formatting with</span>
  535. <span class="sd"> the optional datefmt argument. If datefmt is omitted, you get an</span>
  536. <span class="sd"> ISO8601-like (or RFC 3339-like) format.</span>
  537. <span class="sd"> Use a style parameter of &#39;%&#39;, &#39;{&#39; or &#39;$&#39; to specify that you want to</span>
  538. <span class="sd"> use one of %-formatting, :meth:`str.format` (``{}``) formatting or</span>
  539. <span class="sd"> :class:`string.Template` formatting in your format string.</span>
  540. <span class="sd"> .. versionchanged:: 3.2</span>
  541. <span class="sd"> Added the ``style`` parameter.</span>
  542. <span class="sd"> &quot;&quot;&quot;</span>
  543. <span class="k">if</span> <span class="n">style</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">_STYLES</span><span class="p">:</span>
  544. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">&#39;Style must be one of: </span><span class="si">%s</span><span class="s1">&#39;</span> <span class="o">%</span> <span class="s1">&#39;,&#39;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span>
  545. <span class="n">_STYLES</span><span class="o">.</span><span class="n">keys</span><span class="p">()))</span>
  546. <span class="bp">self</span><span class="o">.</span><span class="n">_style</span> <span class="o">=</span> <span class="n">_STYLES</span><span class="p">[</span><span class="n">style</span><span class="p">][</span><span class="mi">0</span><span class="p">](</span><span class="n">fmt</span><span class="p">)</span>
  547. <span class="k">if</span> <span class="n">validate</span><span class="p">:</span>
  548. <span class="bp">self</span><span class="o">.</span><span class="n">_style</span><span class="o">.</span><span class="n">validate</span><span class="p">()</span>
  549. <span class="bp">self</span><span class="o">.</span><span class="n">_fmt</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_style</span><span class="o">.</span><span class="n">_fmt</span>
  550. <span class="bp">self</span><span class="o">.</span><span class="n">datefmt</span> <span class="o">=</span> <span class="n">datefmt</span>
  551. <span class="n">default_time_format</span> <span class="o">=</span> <span class="s1">&#39;%Y-%m-</span><span class="si">%d</span><span class="s1"> %H:%M:%S&#39;</span>
  552. <span class="n">default_msec_format</span> <span class="o">=</span> <span class="s1">&#39;</span><span class="si">%s</span><span class="s1">,</span><span class="si">%03d</span><span class="s1">&#39;</span>
  553. <span class="k">def</span> <span class="nf">formatTime</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">record</span><span class="p">,</span> <span class="n">datefmt</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  554. <span class="sd">&quot;&quot;&quot;</span>
  555. <span class="sd"> Return the creation time of the specified LogRecord as formatted text.</span>
  556. <span class="sd"> This method should be called from format() by a formatter which</span>
  557. <span class="sd"> wants to make use of a formatted time. This method can be overridden</span>
  558. <span class="sd"> in formatters to provide for any specific requirement, but the</span>
  559. <span class="sd"> basic behaviour is as follows: if datefmt (a string) is specified,</span>
  560. <span class="sd"> it is used with time.strftime() to format the creation time of the</span>
  561. <span class="sd"> record. Otherwise, an ISO8601-like (or RFC 3339-like) format is used.</span>
  562. <span class="sd"> The resulting string is returned. This function uses a user-configurable</span>
  563. <span class="sd"> function to convert the creation time to a tuple. By default,</span>
  564. <span class="sd"> time.localtime() is used; to change this for a particular formatter</span>
  565. <span class="sd"> instance, set the &#39;converter&#39; attribute to a function with the same</span>
  566. <span class="sd"> signature as time.localtime() or time.gmtime(). To change it for all</span>
  567. <span class="sd"> formatters, for example if you want all logging times to be shown in GMT,</span>
  568. <span class="sd"> set the &#39;converter&#39; attribute in the Formatter class.</span>
  569. <span class="sd"> &quot;&quot;&quot;</span>
  570. <span class="n">ct</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">converter</span><span class="p">(</span><span class="n">record</span><span class="o">.</span><span class="n">created</span><span class="p">)</span>
  571. <span class="k">if</span> <span class="n">datefmt</span><span class="p">:</span>
  572. <span class="n">s</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">strftime</span><span class="p">(</span><span class="n">datefmt</span><span class="p">,</span> <span class="n">ct</span><span class="p">)</span>
  573. <span class="k">else</span><span class="p">:</span>
  574. <span class="n">s</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">strftime</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">default_time_format</span><span class="p">,</span> <span class="n">ct</span><span class="p">)</span>
  575. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">default_msec_format</span><span class="p">:</span>
  576. <span class="n">s</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">default_msec_format</span> <span class="o">%</span> <span class="p">(</span><span class="n">s</span><span class="p">,</span> <span class="n">record</span><span class="o">.</span><span class="n">msecs</span><span class="p">)</span>
  577. <span class="k">return</span> <span class="n">s</span>
  578. <span class="k">def</span> <span class="nf">formatException</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">ei</span><span class="p">):</span>
  579. <span class="sd">&quot;&quot;&quot;</span>
  580. <span class="sd"> Format and return the specified exception information as a string.</span>
  581. <span class="sd"> This default implementation just uses</span>
  582. <span class="sd"> traceback.print_exception()</span>
  583. <span class="sd"> &quot;&quot;&quot;</span>
  584. <span class="n">sio</span> <span class="o">=</span> <span class="n">io</span><span class="o">.</span><span class="n">StringIO</span><span class="p">()</span>
  585. <span class="n">tb</span> <span class="o">=</span> <span class="n">ei</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span>
  586. <span class="c1"># See issues #9427, #1553375. Commented out for now.</span>
  587. <span class="c1">#if getattr(self, &#39;fullstack&#39;, False):</span>
  588. <span class="c1"># traceback.print_stack(tb.tb_frame.f_back, file=sio)</span>
  589. <span class="n">traceback</span><span class="o">.</span><span class="n">print_exception</span><span class="p">(</span><span class="n">ei</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">ei</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">tb</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="n">sio</span><span class="p">)</span>
  590. <span class="n">s</span> <span class="o">=</span> <span class="n">sio</span><span class="o">.</span><span class="n">getvalue</span><span class="p">()</span>
  591. <span class="n">sio</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>
  592. <span class="k">if</span> <span class="n">s</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">:]</span> <span class="o">==</span> <span class="s2">&quot;</span><span class="se">\n</span><span class="s2">&quot;</span><span class="p">:</span>
  593. <span class="n">s</span> <span class="o">=</span> <span class="n">s</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
  594. <span class="k">return</span> <span class="n">s</span>
  595. <span class="k">def</span> <span class="nf">usesTime</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  596. <span class="sd">&quot;&quot;&quot;</span>
  597. <span class="sd"> Check if the format uses the creation time of the record.</span>
  598. <span class="sd"> &quot;&quot;&quot;</span>
  599. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_style</span><span class="o">.</span><span class="n">usesTime</span><span class="p">()</span>
  600. <span class="k">def</span> <span class="nf">formatMessage</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">record</span><span class="p">):</span>
  601. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_style</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">record</span><span class="p">)</span>
  602. <span class="k">def</span> <span class="nf">formatStack</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">stack_info</span><span class="p">):</span>
  603. <span class="sd">&quot;&quot;&quot;</span>
  604. <span class="sd"> This method is provided as an extension point for specialized</span>
  605. <span class="sd"> formatting of stack information.</span>
  606. <span class="sd"> The input data is a string as returned from a call to</span>
  607. <span class="sd"> :func:`traceback.print_stack`, but with the last trailing newline</span>
  608. <span class="sd"> removed.</span>
  609. <span class="sd"> The base implementation just returns the value passed in.</span>
  610. <span class="sd"> &quot;&quot;&quot;</span>
  611. <span class="k">return</span> <span class="n">stack_info</span>
  612. <span class="k">def</span> <span class="nf">format</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">record</span><span class="p">):</span>
  613. <span class="sd">&quot;&quot;&quot;</span>
  614. <span class="sd"> Format the specified record as text.</span>
  615. <span class="sd"> The record&#39;s attribute dictionary is used as the operand to a</span>
  616. <span class="sd"> string formatting operation which yields the returned string.</span>
  617. <span class="sd"> Before formatting the dictionary, a couple of preparatory steps</span>
  618. <span class="sd"> are carried out. The message attribute of the record is computed</span>
  619. <span class="sd"> using LogRecord.getMessage(). If the formatting string uses the</span>
  620. <span class="sd"> time (as determined by a call to usesTime(), formatTime() is</span>
  621. <span class="sd"> called to format the event time. If there is exception information,</span>
  622. <span class="sd"> it is formatted using formatException() and appended to the message.</span>
  623. <span class="sd"> &quot;&quot;&quot;</span>
  624. <span class="n">record</span><span class="o">.</span><span class="n">message</span> <span class="o">=</span> <span class="n">record</span><span class="o">.</span><span class="n">getMessage</span><span class="p">()</span>
  625. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">usesTime</span><span class="p">():</span>
  626. <span class="n">record</span><span class="o">.</span><span class="n">asctime</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">formatTime</span><span class="p">(</span><span class="n">record</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">datefmt</span><span class="p">)</span>
  627. <span class="n">s</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">formatMessage</span><span class="p">(</span><span class="n">record</span><span class="p">)</span>
  628. <span class="k">if</span> <span class="n">record</span><span class="o">.</span><span class="n">exc_info</span><span class="p">:</span>
  629. <span class="c1"># Cache the traceback text to avoid converting it multiple times</span>
  630. <span class="c1"># (it&#39;s constant anyway)</span>
  631. <span class="k">if</span> <span class="ow">not</span> <span class="n">record</span><span class="o">.</span><span class="n">exc_text</span><span class="p">:</span>
  632. <span class="n">record</span><span class="o">.</span><span class="n">exc_text</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">formatException</span><span class="p">(</span><span class="n">record</span><span class="o">.</span><span class="n">exc_info</span><span class="p">)</span>
  633. <span class="k">if</span> <span class="n">record</span><span class="o">.</span><span class="n">exc_text</span><span class="p">:</span>
  634. <span class="k">if</span> <span class="n">s</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">:]</span> <span class="o">!=</span> <span class="s2">&quot;</span><span class="se">\n</span><span class="s2">&quot;</span><span class="p">:</span>
  635. <span class="n">s</span> <span class="o">=</span> <span class="n">s</span> <span class="o">+</span> <span class="s2">&quot;</span><span class="se">\n</span><span class="s2">&quot;</span>
  636. <span class="n">s</span> <span class="o">=</span> <span class="n">s</span> <span class="o">+</span> <span class="n">record</span><span class="o">.</span><span class="n">exc_text</span>
  637. <span class="k">if</span> <span class="n">record</span><span class="o">.</span><span class="n">stack_info</span><span class="p">:</span>
  638. <span class="k">if</span> <span class="n">s</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">:]</span> <span class="o">!=</span> <span class="s2">&quot;</span><span class="se">\n</span><span class="s2">&quot;</span><span class="p">:</span>
  639. <span class="n">s</span> <span class="o">=</span> <span class="n">s</span> <span class="o">+</span> <span class="s2">&quot;</span><span class="se">\n</span><span class="s2">&quot;</span>
  640. <span class="n">s</span> <span class="o">=</span> <span class="n">s</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">formatStack</span><span class="p">(</span><span class="n">record</span><span class="o">.</span><span class="n">stack_info</span><span class="p">)</span>
  641. <span class="k">return</span> <span class="n">s</span>
  642. <span class="c1">#</span>
  643. <span class="c1"># The default formatter to use when no other is specified</span>
  644. <span class="c1">#</span>
  645. <span class="n">_defaultFormatter</span> <span class="o">=</span> <span class="n">Formatter</span><span class="p">()</span>
  646. <span class="k">class</span> <span class="nc">BufferingFormatter</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
  647. <span class="sd">&quot;&quot;&quot;</span>
  648. <span class="sd"> A formatter suitable for formatting a number of records.</span>
  649. <span class="sd"> &quot;&quot;&quot;</span>
  650. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">linefmt</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  651. <span class="sd">&quot;&quot;&quot;</span>
  652. <span class="sd"> Optionally specify a formatter which will be used to format each</span>
  653. <span class="sd"> individual record.</span>
  654. <span class="sd"> &quot;&quot;&quot;</span>
  655. <span class="k">if</span> <span class="n">linefmt</span><span class="p">:</span>
  656. <span class="bp">self</span><span class="o">.</span><span class="n">linefmt</span> <span class="o">=</span> <span class="n">linefmt</span>
  657. <span class="k">else</span><span class="p">:</span>
  658. <span class="bp">self</span><span class="o">.</span><span class="n">linefmt</span> <span class="o">=</span> <span class="n">_defaultFormatter</span>
  659. <span class="k">def</span> <span class="nf">formatHeader</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">records</span><span class="p">):</span>
  660. <span class="sd">&quot;&quot;&quot;</span>
  661. <span class="sd"> Return the header string for the specified records.</span>
  662. <span class="sd"> &quot;&quot;&quot;</span>
  663. <span class="k">return</span> <span class="s2">&quot;&quot;</span>
  664. <span class="k">def</span> <span class="nf">formatFooter</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">records</span><span class="p">):</span>
  665. <span class="sd">&quot;&quot;&quot;</span>
  666. <span class="sd"> Return the footer string for the specified records.</span>
  667. <span class="sd"> &quot;&quot;&quot;</span>
  668. <span class="k">return</span> <span class="s2">&quot;&quot;</span>
  669. <span class="k">def</span> <span class="nf">format</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">records</span><span class="p">):</span>
  670. <span class="sd">&quot;&quot;&quot;</span>
  671. <span class="sd"> Format the specified records and return the result as a string.</span>
  672. <span class="sd"> &quot;&quot;&quot;</span>
  673. <span class="n">rv</span> <span class="o">=</span> <span class="s2">&quot;&quot;</span>
  674. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">records</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
  675. <span class="n">rv</span> <span class="o">=</span> <span class="n">rv</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">formatHeader</span><span class="p">(</span><span class="n">records</span><span class="p">)</span>
  676. <span class="k">for</span> <span class="n">record</span> <span class="ow">in</span> <span class="n">records</span><span class="p">:</span>
  677. <span class="n">rv</span> <span class="o">=</span> <span class="n">rv</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">linefmt</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">record</span><span class="p">)</span>
  678. <span class="n">rv</span> <span class="o">=</span> <span class="n">rv</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">formatFooter</span><span class="p">(</span><span class="n">records</span><span class="p">)</span>
  679. <span class="k">return</span> <span class="n">rv</span>
  680. <span class="c1">#---------------------------------------------------------------------------</span>
  681. <span class="c1"># Filter classes and functions</span>
  682. <span class="c1">#---------------------------------------------------------------------------</span>
  683. <span class="k">class</span> <span class="nc">Filter</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
  684. <span class="sd">&quot;&quot;&quot;</span>
  685. <span class="sd"> Filter instances are used to perform arbitrary filtering of LogRecords.</span>
  686. <span class="sd"> Loggers and Handlers can optionally use Filter instances to filter</span>
  687. <span class="sd"> records as desired. The base filter class only allows events which are</span>
  688. <span class="sd"> below a certain point in the logger hierarchy. For example, a filter</span>
  689. <span class="sd"> initialized with &quot;A.B&quot; will allow events logged by loggers &quot;A.B&quot;,</span>
  690. <span class="sd"> &quot;A.B.C&quot;, &quot;A.B.C.D&quot;, &quot;A.B.D&quot; etc. but not &quot;A.BB&quot;, &quot;B.A.B&quot; etc. If</span>
  691. <span class="sd"> initialized with the empty string, all events are passed.</span>
  692. <span class="sd"> &quot;&quot;&quot;</span>
  693. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">&#39;&#39;</span><span class="p">):</span>
  694. <span class="sd">&quot;&quot;&quot;</span>
  695. <span class="sd"> Initialize a filter.</span>
  696. <span class="sd"> Initialize with the name of the logger which, together with its</span>
  697. <span class="sd"> children, will have its events allowed through the filter. If no</span>
  698. <span class="sd"> name is specified, allow every event.</span>
  699. <span class="sd"> &quot;&quot;&quot;</span>
  700. <span class="bp">self</span><span class="o">.</span><span class="n">name</span> <span class="o">=</span> <span class="n">name</span>
  701. <span class="bp">self</span><span class="o">.</span><span class="n">nlen</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
  702. <span class="k">def</span> <span class="nf">filter</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">record</span><span class="p">):</span>
  703. <span class="sd">&quot;&quot;&quot;</span>
  704. <span class="sd"> Determine if the specified record is to be logged.</span>
  705. <span class="sd"> Returns True if the record should be logged, or False otherwise.</span>
  706. <span class="sd"> If deemed appropriate, the record may be modified in-place.</span>
  707. <span class="sd"> &quot;&quot;&quot;</span>
  708. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">nlen</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
  709. <span class="k">return</span> <span class="kc">True</span>
  710. <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">name</span> <span class="o">==</span> <span class="n">record</span><span class="o">.</span><span class="n">name</span><span class="p">:</span>
  711. <span class="k">return</span> <span class="kc">True</span>
  712. <span class="k">elif</span> <span class="n">record</span><span class="o">.</span><span class="n">name</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">nlen</span><span class="p">)</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
  713. <span class="k">return</span> <span class="kc">False</span>
  714. <span class="k">return</span> <span class="p">(</span><span class="n">record</span><span class="o">.</span><span class="n">name</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">nlen</span><span class="p">]</span> <span class="o">==</span> <span class="s2">&quot;.&quot;</span><span class="p">)</span>
  715. <span class="k">class</span> <span class="nc">Filterer</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
  716. <span class="sd">&quot;&quot;&quot;</span>
  717. <span class="sd"> A base class for loggers and handlers which allows them to share</span>
  718. <span class="sd"> common code.</span>
  719. <span class="sd"> &quot;&quot;&quot;</span>
  720. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  721. <span class="sd">&quot;&quot;&quot;</span>
  722. <span class="sd"> Initialize the list of filters to be an empty list.</span>
  723. <span class="sd"> &quot;&quot;&quot;</span>
  724. <span class="bp">self</span><span class="o">.</span><span class="n">filters</span> <span class="o">=</span> <span class="p">[]</span>
  725. <span class="k">def</span> <span class="nf">addFilter</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="nb">filter</span><span class="p">):</span>
  726. <span class="sd">&quot;&quot;&quot;</span>
  727. <span class="sd"> Add the specified filter to this handler.</span>
  728. <span class="sd"> &quot;&quot;&quot;</span>
  729. <span class="k">if</span> <span class="ow">not</span> <span class="p">(</span><span class="nb">filter</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">filters</span><span class="p">):</span>
  730. <span class="bp">self</span><span class="o">.</span><span class="n">filters</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="nb">filter</span><span class="p">)</span>
  731. <span class="k">def</span> <span class="nf">removeFilter</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="nb">filter</span><span class="p">):</span>
  732. <span class="sd">&quot;&quot;&quot;</span>
  733. <span class="sd"> Remove the specified filter from this handler.</span>
  734. <span class="sd"> &quot;&quot;&quot;</span>
  735. <span class="k">if</span> <span class="nb">filter</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">filters</span><span class="p">:</span>
  736. <span class="bp">self</span><span class="o">.</span><span class="n">filters</span><span class="o">.</span><span class="n">remove</span><span class="p">(</span><span class="nb">filter</span><span class="p">)</span>
  737. <span class="k">def</span> <span class="nf">filter</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">record</span><span class="p">):</span>
  738. <span class="sd">&quot;&quot;&quot;</span>
  739. <span class="sd"> Determine if a record is loggable by consulting all the filters.</span>
  740. <span class="sd"> The default is to allow the record to be logged; any filter can veto</span>
  741. <span class="sd"> this and the record is then dropped. Returns a zero value if a record</span>
  742. <span class="sd"> is to be dropped, else non-zero.</span>
  743. <span class="sd"> .. versionchanged:: 3.2</span>
  744. <span class="sd"> Allow filters to be just callables.</span>
  745. <span class="sd"> &quot;&quot;&quot;</span>
  746. <span class="n">rv</span> <span class="o">=</span> <span class="kc">True</span>
  747. <span class="k">for</span> <span class="n">f</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">filters</span><span class="p">:</span>
  748. <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">f</span><span class="p">,</span> <span class="s1">&#39;filter&#39;</span><span class="p">):</span>
  749. <span class="n">result</span> <span class="o">=</span> <span class="n">f</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">record</span><span class="p">)</span>
  750. <span class="k">else</span><span class="p">:</span>
  751. <span class="n">result</span> <span class="o">=</span> <span class="n">f</span><span class="p">(</span><span class="n">record</span><span class="p">)</span> <span class="c1"># assume callable - will raise if not</span>
  752. <span class="k">if</span> <span class="ow">not</span> <span class="n">result</span><span class="p">:</span>
  753. <span class="n">rv</span> <span class="o">=</span> <span class="kc">False</span>
  754. <span class="k">break</span>
  755. <span class="k">return</span> <span class="n">rv</span>
  756. <span class="c1">#---------------------------------------------------------------------------</span>
  757. <span class="c1"># Handler classes and functions</span>
  758. <span class="c1">#---------------------------------------------------------------------------</span>
  759. <span class="n">_handlers</span> <span class="o">=</span> <span class="n">weakref</span><span class="o">.</span><span class="n">WeakValueDictionary</span><span class="p">()</span> <span class="c1">#map of handler names to handlers</span>
  760. <span class="n">_handlerList</span> <span class="o">=</span> <span class="p">[]</span> <span class="c1"># added to allow handlers to be removed in reverse of order initialized</span>
  761. <span class="k">def</span> <span class="nf">_removeHandlerRef</span><span class="p">(</span><span class="n">wr</span><span class="p">):</span>
  762. <span class="sd">&quot;&quot;&quot;</span>
  763. <span class="sd"> Remove a handler reference from the internal cleanup list.</span>
  764. <span class="sd"> &quot;&quot;&quot;</span>
  765. <span class="c1"># This function can be called during module teardown, when globals are</span>
  766. <span class="c1"># set to None. It can also be called from another thread. So we need to</span>
  767. <span class="c1"># pre-emptively grab the necessary globals and check if they&#39;re None,</span>
  768. <span class="c1"># to prevent race conditions and failures during interpreter shutdown.</span>
  769. <span class="n">acquire</span><span class="p">,</span> <span class="n">release</span><span class="p">,</span> <span class="n">handlers</span> <span class="o">=</span> <span class="n">_acquireLock</span><span class="p">,</span> <span class="n">_releaseLock</span><span class="p">,</span> <span class="n">_handlerList</span>
  770. <span class="k">if</span> <span class="n">acquire</span> <span class="ow">and</span> <span class="n">release</span> <span class="ow">and</span> <span class="n">handlers</span><span class="p">:</span>
  771. <span class="n">acquire</span><span class="p">()</span>
  772. <span class="k">try</span><span class="p">:</span>
  773. <span class="k">if</span> <span class="n">wr</span> <span class="ow">in</span> <span class="n">handlers</span><span class="p">:</span>
  774. <span class="n">handlers</span><span class="o">.</span><span class="n">remove</span><span class="p">(</span><span class="n">wr</span><span class="p">)</span>
  775. <span class="k">finally</span><span class="p">:</span>
  776. <span class="n">release</span><span class="p">()</span>
  777. <span class="k">def</span> <span class="nf">_addHandlerRef</span><span class="p">(</span><span class="n">handler</span><span class="p">):</span>
  778. <span class="sd">&quot;&quot;&quot;</span>
  779. <span class="sd"> Add a handler to the internal cleanup list using a weak reference.</span>
  780. <span class="sd"> &quot;&quot;&quot;</span>
  781. <span class="n">_acquireLock</span><span class="p">()</span>
  782. <span class="k">try</span><span class="p">:</span>
  783. <span class="n">_handlerList</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">weakref</span><span class="o">.</span><span class="n">ref</span><span class="p">(</span><span class="n">handler</span><span class="p">,</span> <span class="n">_removeHandlerRef</span><span class="p">))</span>
  784. <span class="k">finally</span><span class="p">:</span>
  785. <span class="n">_releaseLock</span><span class="p">()</span>
  786. <span class="k">class</span> <span class="nc">Handler</span><span class="p">(</span><span class="n">Filterer</span><span class="p">):</span>
  787. <span class="sd">&quot;&quot;&quot;</span>
  788. <span class="sd"> Handler instances dispatch logging events to specific destinations.</span>
  789. <span class="sd"> The base handler class. Acts as a placeholder which defines the Handler</span>
  790. <span class="sd"> interface. Handlers can optionally use Formatter instances to format</span>
  791. <span class="sd"> records as desired. By default, no formatter is specified; in this case,</span>
  792. <span class="sd"> the &#39;raw&#39; message as determined by record.message is logged.</span>
  793. <span class="sd"> &quot;&quot;&quot;</span>
  794. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">level</span><span class="o">=</span><span class="n">NOTSET</span><span class="p">):</span>
  795. <span class="sd">&quot;&quot;&quot;</span>
  796. <span class="sd"> Initializes the instance - basically setting the formatter to None</span>
  797. <span class="sd"> and the filter list to empty.</span>
  798. <span class="sd"> &quot;&quot;&quot;</span>
  799. <span class="n">Filterer</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span>
  800. <span class="bp">self</span><span class="o">.</span><span class="n">_name</span> <span class="o">=</span> <span class="kc">None</span>
  801. <span class="bp">self</span><span class="o">.</span><span class="n">level</span> <span class="o">=</span> <span class="n">_checkLevel</span><span class="p">(</span><span class="n">level</span><span class="p">)</span>
  802. <span class="bp">self</span><span class="o">.</span><span class="n">formatter</span> <span class="o">=</span> <span class="kc">None</span>
  803. <span class="c1"># Add the handler to the global _handlerList (for cleanup on shutdown)</span>
  804. <span class="n">_addHandlerRef</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span>
  805. <span class="bp">self</span><span class="o">.</span><span class="n">createLock</span><span class="p">()</span>
  806. <span class="k">def</span> <span class="nf">get_name</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  807. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_name</span>
  808. <span class="k">def</span> <span class="nf">set_name</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
  809. <span class="n">_acquireLock</span><span class="p">()</span>
  810. <span class="k">try</span><span class="p">:</span>
  811. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_name</span> <span class="ow">in</span> <span class="n">_handlers</span><span class="p">:</span>
  812. <span class="k">del</span> <span class="n">_handlers</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_name</span><span class="p">]</span>
  813. <span class="bp">self</span><span class="o">.</span><span class="n">_name</span> <span class="o">=</span> <span class="n">name</span>
  814. <span class="k">if</span> <span class="n">name</span><span class="p">:</span>
  815. <span class="n">_handlers</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span>
  816. <span class="k">finally</span><span class="p">:</span>
  817. <span class="n">_releaseLock</span><span class="p">()</span>
  818. <span class="n">name</span> <span class="o">=</span> <span class="nb">property</span><span class="p">(</span><span class="n">get_name</span><span class="p">,</span> <span class="n">set_name</span><span class="p">)</span>
  819. <span class="k">def</span> <span class="nf">createLock</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  820. <span class="sd">&quot;&quot;&quot;</span>
  821. <span class="sd"> Acquire a thread lock for serializing access to the underlying I/O.</span>
  822. <span class="sd"> &quot;&quot;&quot;</span>
  823. <span class="bp">self</span><span class="o">.</span><span class="n">lock</span> <span class="o">=</span> <span class="n">threading</span><span class="o">.</span><span class="n">RLock</span><span class="p">()</span>
  824. <span class="n">_register_at_fork_reinit_lock</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span>
  825. <span class="k">def</span> <span class="nf">_at_fork_reinit</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  826. <span class="bp">self</span><span class="o">.</span><span class="n">lock</span><span class="o">.</span><span class="n">_at_fork_reinit</span><span class="p">()</span>
  827. <span class="k">def</span> <span class="nf">acquire</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  828. <span class="sd">&quot;&quot;&quot;</span>
  829. <span class="sd"> Acquire the I/O thread lock.</span>
  830. <span class="sd"> &quot;&quot;&quot;</span>
  831. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">lock</span><span class="p">:</span>
  832. <span class="bp">self</span><span class="o">.</span><span class="n">lock</span><span class="o">.</span><span class="n">acquire</span><span class="p">()</span>
  833. <span class="k">def</span> <span class="nf">release</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  834. <span class="sd">&quot;&quot;&quot;</span>
  835. <span class="sd"> Release the I/O thread lock.</span>
  836. <span class="sd"> &quot;&quot;&quot;</span>
  837. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">lock</span><span class="p">:</span>
  838. <span class="bp">self</span><span class="o">.</span><span class="n">lock</span><span class="o">.</span><span class="n">release</span><span class="p">()</span>
  839. <span class="k">def</span> <span class="nf">setLevel</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">level</span><span class="p">):</span>
  840. <span class="sd">&quot;&quot;&quot;</span>
  841. <span class="sd"> Set the logging level of this handler. level must be an int or a str.</span>
  842. <span class="sd"> &quot;&quot;&quot;</span>
  843. <span class="bp">self</span><span class="o">.</span><span class="n">level</span> <span class="o">=</span> <span class="n">_checkLevel</span><span class="p">(</span><span class="n">level</span><span class="p">)</span>
  844. <span class="k">def</span> <span class="nf">format</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">record</span><span class="p">):</span>
  845. <span class="sd">&quot;&quot;&quot;</span>
  846. <span class="sd"> Format the specified record.</span>
  847. <span class="sd"> If a formatter is set, use it. Otherwise, use the default formatter</span>
  848. <span class="sd"> for the module.</span>
  849. <span class="sd"> &quot;&quot;&quot;</span>
  850. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">formatter</span><span class="p">:</span>
  851. <span class="n">fmt</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">formatter</span>
  852. <span class="k">else</span><span class="p">:</span>
  853. <span class="n">fmt</span> <span class="o">=</span> <span class="n">_defaultFormatter</span>
  854. <span class="k">return</span> <span class="n">fmt</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">record</span><span class="p">)</span>
  855. <span class="k">def</span> <span class="nf">emit</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">record</span><span class="p">):</span>
  856. <span class="sd">&quot;&quot;&quot;</span>
  857. <span class="sd"> Do whatever it takes to actually log the specified logging record.</span>
  858. <span class="sd"> This version is intended to be implemented by subclasses and so</span>
  859. <span class="sd"> raises a NotImplementedError.</span>
  860. <span class="sd"> &quot;&quot;&quot;</span>
  861. <span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s1">&#39;emit must be implemented &#39;</span>
  862. <span class="s1">&#39;by Handler subclasses&#39;</span><span class="p">)</span>
  863. <span class="k">def</span> <span class="nf">handle</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">record</span><span class="p">):</span>
  864. <span class="sd">&quot;&quot;&quot;</span>
  865. <span class="sd"> Conditionally emit the specified logging record.</span>
  866. <span class="sd"> Emission depends on filters which may have been added to the handler.</span>
  867. <span class="sd"> Wrap the actual emission of the record with acquisition/release of</span>
  868. <span class="sd"> the I/O thread lock. Returns whether the filter passed the record for</span>
  869. <span class="sd"> emission.</span>
  870. <span class="sd"> &quot;&quot;&quot;</span>
  871. <span class="n">rv</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">record</span><span class="p">)</span>
  872. <span class="k">if</span> <span class="n">rv</span><span class="p">:</span>
  873. <span class="bp">self</span><span class="o">.</span><span class="n">acquire</span><span class="p">()</span>
  874. <span class="k">try</span><span class="p">:</span>
  875. <span class="bp">self</span><span class="o">.</span><span class="n">emit</span><span class="p">(</span><span class="n">record</span><span class="p">)</span>
  876. <span class="k">finally</span><span class="p">:</span>
  877. <span class="bp">self</span><span class="o">.</span><span class="n">release</span><span class="p">()</span>
  878. <span class="k">return</span> <span class="n">rv</span>
  879. <span class="k">def</span> <span class="nf">setFormatter</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">fmt</span><span class="p">):</span>
  880. <span class="sd">&quot;&quot;&quot;</span>
  881. <span class="sd"> Set the formatter for this handler.</span>
  882. <span class="sd"> &quot;&quot;&quot;</span>
  883. <span class="bp">self</span><span class="o">.</span><span class="n">formatter</span> <span class="o">=</span> <span class="n">fmt</span>
  884. <span class="k">def</span> <span class="nf">flush</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  885. <span class="sd">&quot;&quot;&quot;</span>
  886. <span class="sd"> Ensure all logging output has been flushed.</span>
  887. <span class="sd"> This version does nothing and is intended to be implemented by</span>
  888. <span class="sd"> subclasses.</span>
  889. <span class="sd"> &quot;&quot;&quot;</span>
  890. <span class="k">pass</span>
  891. <span class="k">def</span> <span class="nf">close</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  892. <span class="sd">&quot;&quot;&quot;</span>
  893. <span class="sd"> Tidy up any resources used by the handler.</span>
  894. <span class="sd"> This version removes the handler from an internal map of handlers,</span>
  895. <span class="sd"> _handlers, which is used for handler lookup by name. Subclasses</span>
  896. <span class="sd"> should ensure that this gets called from overridden close()</span>
  897. <span class="sd"> methods.</span>
  898. <span class="sd"> &quot;&quot;&quot;</span>
  899. <span class="c1">#get the module data lock, as we&#39;re updating a shared structure.</span>
  900. <span class="n">_acquireLock</span><span class="p">()</span>
  901. <span class="k">try</span><span class="p">:</span> <span class="c1">#unlikely to raise an exception, but you never know...</span>
  902. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_name</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">_name</span> <span class="ow">in</span> <span class="n">_handlers</span><span class="p">:</span>
  903. <span class="k">del</span> <span class="n">_handlers</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_name</span><span class="p">]</span>
  904. <span class="k">finally</span><span class="p">:</span>
  905. <span class="n">_releaseLock</span><span class="p">()</span>
  906. <span class="k">def</span> <span class="nf">handleError</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">record</span><span class="p">):</span>
  907. <span class="sd">&quot;&quot;&quot;</span>
  908. <span class="sd"> Handle errors which occur during an emit() call.</span>
  909. <span class="sd"> This method should be called from handlers when an exception is</span>
  910. <span class="sd"> encountered during an emit() call. If raiseExceptions is false,</span>
  911. <span class="sd"> exceptions get silently ignored. This is what is mostly wanted</span>
  912. <span class="sd"> for a logging system - most users will not care about errors in</span>
  913. <span class="sd"> the logging system, they are more interested in application errors.</span>
  914. <span class="sd"> You could, however, replace this with a custom handler if you wish.</span>
  915. <span class="sd"> The record which was being processed is passed in to this method.</span>
  916. <span class="sd"> &quot;&quot;&quot;</span>
  917. <span class="k">if</span> <span class="n">raiseExceptions</span> <span class="ow">and</span> <span class="n">sys</span><span class="o">.</span><span class="n">stderr</span><span class="p">:</span> <span class="c1"># see issue 13807</span>
  918. <span class="n">t</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">tb</span> <span class="o">=</span> <span class="n">sys</span><span class="o">.</span><span class="n">exc_info</span><span class="p">()</span>
  919. <span class="k">try</span><span class="p">:</span>
  920. <span class="n">sys</span><span class="o">.</span><span class="n">stderr</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="s1">&#39;--- Logging error ---</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">)</span>
  921. <span class="n">traceback</span><span class="o">.</span><span class="n">print_exception</span><span class="p">(</span><span class="n">t</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">tb</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="n">sys</span><span class="o">.</span><span class="n">stderr</span><span class="p">)</span>
  922. <span class="n">sys</span><span class="o">.</span><span class="n">stderr</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="s1">&#39;Call stack:</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">)</span>
  923. <span class="c1"># Walk the stack frame up until we&#39;re out of logging,</span>
  924. <span class="c1"># so as to print the calling context.</span>
  925. <span class="n">frame</span> <span class="o">=</span> <span class="n">tb</span><span class="o">.</span><span class="n">tb_frame</span>
  926. <span class="k">while</span> <span class="p">(</span><span class="n">frame</span> <span class="ow">and</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">dirname</span><span class="p">(</span><span class="n">frame</span><span class="o">.</span><span class="n">f_code</span><span class="o">.</span><span class="n">co_filename</span><span class="p">)</span> <span class="o">==</span>
  927. <span class="n">__path__</span><span class="p">[</span><span class="mi">0</span><span class="p">]):</span>
  928. <span class="n">frame</span> <span class="o">=</span> <span class="n">frame</span><span class="o">.</span><span class="n">f_back</span>
  929. <span class="k">if</span> <span class="n">frame</span><span class="p">:</span>
  930. <span class="n">traceback</span><span class="o">.</span><span class="n">print_stack</span><span class="p">(</span><span class="n">frame</span><span class="p">,</span> <span class="n">file</span><span class="o">=</span><span class="n">sys</span><span class="o">.</span><span class="n">stderr</span><span class="p">)</span>
  931. <span class="k">else</span><span class="p">:</span>
  932. <span class="c1"># couldn&#39;t find the right stack frame, for some reason</span>
  933. <span class="n">sys</span><span class="o">.</span><span class="n">stderr</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="s1">&#39;Logged from file </span><span class="si">%s</span><span class="s1">, line </span><span class="si">%s</span><span class="se">\n</span><span class="s1">&#39;</span> <span class="o">%</span> <span class="p">(</span>
  934. <span class="n">record</span><span class="o">.</span><span class="n">filename</span><span class="p">,</span> <span class="n">record</span><span class="o">.</span><span class="n">lineno</span><span class="p">))</span>
  935. <span class="c1"># Issue 18671: output logging message and arguments</span>
  936. <span class="k">try</span><span class="p">:</span>
  937. <span class="n">sys</span><span class="o">.</span><span class="n">stderr</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="s1">&#39;Message: </span><span class="si">%r</span><span class="se">\n</span><span class="s1">&#39;</span>
  938. <span class="s1">&#39;Arguments: </span><span class="si">%s</span><span class="se">\n</span><span class="s1">&#39;</span> <span class="o">%</span> <span class="p">(</span><span class="n">record</span><span class="o">.</span><span class="n">msg</span><span class="p">,</span>
  939. <span class="n">record</span><span class="o">.</span><span class="n">args</span><span class="p">))</span>
  940. <span class="k">except</span> <span class="ne">RecursionError</span><span class="p">:</span> <span class="c1"># See issue 36272</span>
  941. <span class="k">raise</span>
  942. <span class="k">except</span> <span class="ne">Exception</span><span class="p">:</span>
  943. <span class="n">sys</span><span class="o">.</span><span class="n">stderr</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="s1">&#39;Unable to print the message and arguments&#39;</span>
  944. <span class="s1">&#39; - possible formatting error.</span><span class="se">\n</span><span class="s1">Use the&#39;</span>
  945. <span class="s1">&#39; traceback above to help find the error.</span><span class="se">\n</span><span class="s1">&#39;</span>
  946. <span class="p">)</span>
  947. <span class="k">except</span> <span class="ne">OSError</span><span class="p">:</span> <span class="c1">#pragma: no cover</span>
  948. <span class="k">pass</span> <span class="c1"># see issue 5971</span>
  949. <span class="k">finally</span><span class="p">:</span>
  950. <span class="k">del</span> <span class="n">t</span><span class="p">,</span> <span class="n">v</span><span class="p">,</span> <span class="n">tb</span>
  951. <span class="k">def</span> <span class="fm">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  952. <span class="n">level</span> <span class="o">=</span> <span class="n">getLevelName</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">level</span><span class="p">)</span>
  953. <span class="k">return</span> <span class="s1">&#39;&lt;</span><span class="si">%s</span><span class="s1"> (</span><span class="si">%s</span><span class="s1">)&gt;&#39;</span> <span class="o">%</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">,</span> <span class="n">level</span><span class="p">)</span>
  954. <span class="k">class</span> <span class="nc">StreamHandler</span><span class="p">(</span><span class="n">Handler</span><span class="p">):</span>
  955. <span class="sd">&quot;&quot;&quot;</span>
  956. <span class="sd"> A handler class which writes logging records, appropriately formatted,</span>
  957. <span class="sd"> to a stream. Note that this class does not close the stream, as</span>
  958. <span class="sd"> sys.stdout or sys.stderr may be used.</span>
  959. <span class="sd"> &quot;&quot;&quot;</span>
  960. <span class="n">terminator</span> <span class="o">=</span> <span class="s1">&#39;</span><span class="se">\n</span><span class="s1">&#39;</span>
  961. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">stream</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  962. <span class="sd">&quot;&quot;&quot;</span>
  963. <span class="sd"> Initialize the handler.</span>
  964. <span class="sd"> If stream is not specified, sys.stderr is used.</span>
  965. <span class="sd"> &quot;&quot;&quot;</span>
  966. <span class="n">Handler</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span>
  967. <span class="k">if</span> <span class="n">stream</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
  968. <span class="n">stream</span> <span class="o">=</span> <span class="n">sys</span><span class="o">.</span><span class="n">stderr</span>
  969. <span class="bp">self</span><span class="o">.</span><span class="n">stream</span> <span class="o">=</span> <span class="n">stream</span>
  970. <span class="k">def</span> <span class="nf">flush</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  971. <span class="sd">&quot;&quot;&quot;</span>
  972. <span class="sd"> Flushes the stream.</span>
  973. <span class="sd"> &quot;&quot;&quot;</span>
  974. <span class="bp">self</span><span class="o">.</span><span class="n">acquire</span><span class="p">()</span>
  975. <span class="k">try</span><span class="p">:</span>
  976. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">stream</span> <span class="ow">and</span> <span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">stream</span><span class="p">,</span> <span class="s2">&quot;flush&quot;</span><span class="p">):</span>
  977. <span class="bp">self</span><span class="o">.</span><span class="n">stream</span><span class="o">.</span><span class="n">flush</span><span class="p">()</span>
  978. <span class="k">finally</span><span class="p">:</span>
  979. <span class="bp">self</span><span class="o">.</span><span class="n">release</span><span class="p">()</span>
  980. <span class="k">def</span> <span class="nf">emit</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">record</span><span class="p">):</span>
  981. <span class="sd">&quot;&quot;&quot;</span>
  982. <span class="sd"> Emit a record.</span>
  983. <span class="sd"> If a formatter is specified, it is used to format the record.</span>
  984. <span class="sd"> The record is then written to the stream with a trailing newline. If</span>
  985. <span class="sd"> exception information is present, it is formatted using</span>
  986. <span class="sd"> traceback.print_exception and appended to the stream. If the stream</span>
  987. <span class="sd"> has an &#39;encoding&#39; attribute, it is used to determine how to do the</span>
  988. <span class="sd"> output to the stream.</span>
  989. <span class="sd"> &quot;&quot;&quot;</span>
  990. <span class="k">try</span><span class="p">:</span>
  991. <span class="n">msg</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">record</span><span class="p">)</span>
  992. <span class="n">stream</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">stream</span>
  993. <span class="c1"># issue 35046: merged two stream.writes into one.</span>
  994. <span class="n">stream</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="n">msg</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">terminator</span><span class="p">)</span>
  995. <span class="bp">self</span><span class="o">.</span><span class="n">flush</span><span class="p">()</span>
  996. <span class="k">except</span> <span class="ne">RecursionError</span><span class="p">:</span> <span class="c1"># See issue 36272</span>
  997. <span class="k">raise</span>
  998. <span class="k">except</span> <span class="ne">Exception</span><span class="p">:</span>
  999. <span class="bp">self</span><span class="o">.</span><span class="n">handleError</span><span class="p">(</span><span class="n">record</span><span class="p">)</span>
  1000. <span class="k">def</span> <span class="nf">setStream</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">stream</span><span class="p">):</span>
  1001. <span class="sd">&quot;&quot;&quot;</span>
  1002. <span class="sd"> Sets the StreamHandler&#39;s stream to the specified value,</span>
  1003. <span class="sd"> if it is different.</span>
  1004. <span class="sd"> Returns the old stream, if the stream was changed, or None</span>
  1005. <span class="sd"> if it wasn&#39;t.</span>
  1006. <span class="sd"> &quot;&quot;&quot;</span>
  1007. <span class="k">if</span> <span class="n">stream</span> <span class="ow">is</span> <span class="bp">self</span><span class="o">.</span><span class="n">stream</span><span class="p">:</span>
  1008. <span class="n">result</span> <span class="o">=</span> <span class="kc">None</span>
  1009. <span class="k">else</span><span class="p">:</span>
  1010. <span class="n">result</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">stream</span>
  1011. <span class="bp">self</span><span class="o">.</span><span class="n">acquire</span><span class="p">()</span>
  1012. <span class="k">try</span><span class="p">:</span>
  1013. <span class="bp">self</span><span class="o">.</span><span class="n">flush</span><span class="p">()</span>
  1014. <span class="bp">self</span><span class="o">.</span><span class="n">stream</span> <span class="o">=</span> <span class="n">stream</span>
  1015. <span class="k">finally</span><span class="p">:</span>
  1016. <span class="bp">self</span><span class="o">.</span><span class="n">release</span><span class="p">()</span>
  1017. <span class="k">return</span> <span class="n">result</span>
  1018. <span class="k">def</span> <span class="fm">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  1019. <span class="n">level</span> <span class="o">=</span> <span class="n">getLevelName</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">level</span><span class="p">)</span>
  1020. <span class="n">name</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">stream</span><span class="p">,</span> <span class="s1">&#39;name&#39;</span><span class="p">,</span> <span class="s1">&#39;&#39;</span><span class="p">)</span>
  1021. <span class="c1"># bpo-36015: name can be an int</span>
  1022. <span class="n">name</span> <span class="o">=</span> <span class="nb">str</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
  1023. <span class="k">if</span> <span class="n">name</span><span class="p">:</span>
  1024. <span class="n">name</span> <span class="o">+=</span> <span class="s1">&#39; &#39;</span>
  1025. <span class="k">return</span> <span class="s1">&#39;&lt;</span><span class="si">%s</span><span class="s1"> </span><span class="si">%s</span><span class="s1">(</span><span class="si">%s</span><span class="s1">)&gt;&#39;</span> <span class="o">%</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">level</span><span class="p">)</span>
  1026. <span class="k">class</span> <span class="nc">FileHandler</span><span class="p">(</span><span class="n">StreamHandler</span><span class="p">):</span>
  1027. <span class="sd">&quot;&quot;&quot;</span>
  1028. <span class="sd"> A handler class which writes formatted logging records to disk files.</span>
  1029. <span class="sd"> &quot;&quot;&quot;</span>
  1030. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">filename</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s1">&#39;a&#39;</span><span class="p">,</span> <span class="n">encoding</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">delay</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">errors</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  1031. <span class="sd">&quot;&quot;&quot;</span>
  1032. <span class="sd"> Open the specified file and use it as the stream for logging.</span>
  1033. <span class="sd"> &quot;&quot;&quot;</span>
  1034. <span class="c1"># Issue #27493: add support for Path objects to be passed in</span>
  1035. <span class="n">filename</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">fspath</span><span class="p">(</span><span class="n">filename</span><span class="p">)</span>
  1036. <span class="c1">#keep the absolute path, otherwise derived classes which use this</span>
  1037. <span class="c1">#may come a cropper when the current directory changes</span>
  1038. <span class="bp">self</span><span class="o">.</span><span class="n">baseFilename</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">abspath</span><span class="p">(</span><span class="n">filename</span><span class="p">)</span>
  1039. <span class="bp">self</span><span class="o">.</span><span class="n">mode</span> <span class="o">=</span> <span class="n">mode</span>
  1040. <span class="bp">self</span><span class="o">.</span><span class="n">encoding</span> <span class="o">=</span> <span class="n">encoding</span>
  1041. <span class="bp">self</span><span class="o">.</span><span class="n">errors</span> <span class="o">=</span> <span class="n">errors</span>
  1042. <span class="bp">self</span><span class="o">.</span><span class="n">delay</span> <span class="o">=</span> <span class="n">delay</span>
  1043. <span class="k">if</span> <span class="n">delay</span><span class="p">:</span>
  1044. <span class="c1">#We don&#39;t open the stream, but we still need to call the</span>
  1045. <span class="c1">#Handler constructor to set level, formatter, lock etc.</span>
  1046. <span class="n">Handler</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span>
  1047. <span class="bp">self</span><span class="o">.</span><span class="n">stream</span> <span class="o">=</span> <span class="kc">None</span>
  1048. <span class="k">else</span><span class="p">:</span>
  1049. <span class="n">StreamHandler</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_open</span><span class="p">())</span>
  1050. <span class="k">def</span> <span class="nf">close</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  1051. <span class="sd">&quot;&quot;&quot;</span>
  1052. <span class="sd"> Closes the stream.</span>
  1053. <span class="sd"> &quot;&quot;&quot;</span>
  1054. <span class="bp">self</span><span class="o">.</span><span class="n">acquire</span><span class="p">()</span>
  1055. <span class="k">try</span><span class="p">:</span>
  1056. <span class="k">try</span><span class="p">:</span>
  1057. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">stream</span><span class="p">:</span>
  1058. <span class="k">try</span><span class="p">:</span>
  1059. <span class="bp">self</span><span class="o">.</span><span class="n">flush</span><span class="p">()</span>
  1060. <span class="k">finally</span><span class="p">:</span>
  1061. <span class="n">stream</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">stream</span>
  1062. <span class="bp">self</span><span class="o">.</span><span class="n">stream</span> <span class="o">=</span> <span class="kc">None</span>
  1063. <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">stream</span><span class="p">,</span> <span class="s2">&quot;close&quot;</span><span class="p">):</span>
  1064. <span class="n">stream</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>
  1065. <span class="k">finally</span><span class="p">:</span>
  1066. <span class="c1"># Issue #19523: call unconditionally to</span>
  1067. <span class="c1"># prevent a handler leak when delay is set</span>
  1068. <span class="n">StreamHandler</span><span class="o">.</span><span class="n">close</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span>
  1069. <span class="k">finally</span><span class="p">:</span>
  1070. <span class="bp">self</span><span class="o">.</span><span class="n">release</span><span class="p">()</span>
  1071. <span class="k">def</span> <span class="nf">_open</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  1072. <span class="sd">&quot;&quot;&quot;</span>
  1073. <span class="sd"> Open the current base file with the (original) mode and encoding.</span>
  1074. <span class="sd"> Return the resulting stream.</span>
  1075. <span class="sd"> &quot;&quot;&quot;</span>
  1076. <span class="k">return</span> <span class="nb">open</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">baseFilename</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="p">,</span> <span class="n">encoding</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">encoding</span><span class="p">,</span>
  1077. <span class="n">errors</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">errors</span><span class="p">)</span>
  1078. <span class="k">def</span> <span class="nf">emit</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">record</span><span class="p">):</span>
  1079. <span class="sd">&quot;&quot;&quot;</span>
  1080. <span class="sd"> Emit a record.</span>
  1081. <span class="sd"> If the stream was not opened because &#39;delay&#39; was specified in the</span>
  1082. <span class="sd"> constructor, open it before calling the superclass&#39;s emit.</span>
  1083. <span class="sd"> &quot;&quot;&quot;</span>
  1084. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">stream</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
  1085. <span class="bp">self</span><span class="o">.</span><span class="n">stream</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_open</span><span class="p">()</span>
  1086. <span class="n">StreamHandler</span><span class="o">.</span><span class="n">emit</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">record</span><span class="p">)</span>
  1087. <span class="k">def</span> <span class="fm">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  1088. <span class="n">level</span> <span class="o">=</span> <span class="n">getLevelName</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">level</span><span class="p">)</span>
  1089. <span class="k">return</span> <span class="s1">&#39;&lt;</span><span class="si">%s</span><span class="s1"> </span><span class="si">%s</span><span class="s1"> (</span><span class="si">%s</span><span class="s1">)&gt;&#39;</span> <span class="o">%</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">baseFilename</span><span class="p">,</span> <span class="n">level</span><span class="p">)</span>
  1090. <span class="k">class</span> <span class="nc">_StderrHandler</span><span class="p">(</span><span class="n">StreamHandler</span><span class="p">):</span>
  1091. <span class="sd">&quot;&quot;&quot;</span>
  1092. <span class="sd"> This class is like a StreamHandler using sys.stderr, but always uses</span>
  1093. <span class="sd"> whatever sys.stderr is currently set to rather than the value of</span>
  1094. <span class="sd"> sys.stderr at handler construction time.</span>
  1095. <span class="sd"> &quot;&quot;&quot;</span>
  1096. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">level</span><span class="o">=</span><span class="n">NOTSET</span><span class="p">):</span>
  1097. <span class="sd">&quot;&quot;&quot;</span>
  1098. <span class="sd"> Initialize the handler.</span>
  1099. <span class="sd"> &quot;&quot;&quot;</span>
  1100. <span class="n">Handler</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">level</span><span class="p">)</span>
  1101. <span class="nd">@property</span>
  1102. <span class="k">def</span> <span class="nf">stream</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  1103. <span class="k">return</span> <span class="n">sys</span><span class="o">.</span><span class="n">stderr</span>
  1104. <span class="n">_defaultLastResort</span> <span class="o">=</span> <span class="n">_StderrHandler</span><span class="p">(</span><span class="n">WARNING</span><span class="p">)</span>
  1105. <span class="n">lastResort</span> <span class="o">=</span> <span class="n">_defaultLastResort</span>
  1106. <span class="c1">#---------------------------------------------------------------------------</span>
  1107. <span class="c1"># Manager classes and functions</span>
  1108. <span class="c1">#---------------------------------------------------------------------------</span>
  1109. <span class="k">class</span> <span class="nc">PlaceHolder</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
  1110. <span class="sd">&quot;&quot;&quot;</span>
  1111. <span class="sd"> PlaceHolder instances are used in the Manager logger hierarchy to take</span>
  1112. <span class="sd"> the place of nodes for which no loggers have been defined. This class is</span>
  1113. <span class="sd"> intended for internal use only and not as part of the public API.</span>
  1114. <span class="sd"> &quot;&quot;&quot;</span>
  1115. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">alogger</span><span class="p">):</span>
  1116. <span class="sd">&quot;&quot;&quot;</span>
  1117. <span class="sd"> Initialize with the specified logger being a child of this placeholder.</span>
  1118. <span class="sd"> &quot;&quot;&quot;</span>
  1119. <span class="bp">self</span><span class="o">.</span><span class="n">loggerMap</span> <span class="o">=</span> <span class="p">{</span> <span class="n">alogger</span> <span class="p">:</span> <span class="kc">None</span> <span class="p">}</span>
  1120. <span class="k">def</span> <span class="nf">append</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">alogger</span><span class="p">):</span>
  1121. <span class="sd">&quot;&quot;&quot;</span>
  1122. <span class="sd"> Add the specified logger as a child of this placeholder.</span>
  1123. <span class="sd"> &quot;&quot;&quot;</span>
  1124. <span class="k">if</span> <span class="n">alogger</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">loggerMap</span><span class="p">:</span>
  1125. <span class="bp">self</span><span class="o">.</span><span class="n">loggerMap</span><span class="p">[</span><span class="n">alogger</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span>
  1126. <span class="c1">#</span>
  1127. <span class="c1"># Determine which class to use when instantiating loggers.</span>
  1128. <span class="c1">#</span>
  1129. <span class="k">def</span> <span class="nf">setLoggerClass</span><span class="p">(</span><span class="n">klass</span><span class="p">):</span>
  1130. <span class="sd">&quot;&quot;&quot;</span>
  1131. <span class="sd"> Set the class to be used when instantiating a logger. The class should</span>
  1132. <span class="sd"> define __init__() such that only a name argument is required, and the</span>
  1133. <span class="sd"> __init__() should call Logger.__init__()</span>
  1134. <span class="sd"> &quot;&quot;&quot;</span>
  1135. <span class="k">if</span> <span class="n">klass</span> <span class="o">!=</span> <span class="n">Logger</span><span class="p">:</span>
  1136. <span class="k">if</span> <span class="ow">not</span> <span class="nb">issubclass</span><span class="p">(</span><span class="n">klass</span><span class="p">,</span> <span class="n">Logger</span><span class="p">):</span>
  1137. <span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">&quot;logger not derived from logging.Logger: &quot;</span>
  1138. <span class="o">+</span> <span class="n">klass</span><span class="o">.</span><span class="vm">__name__</span><span class="p">)</span>
  1139. <span class="k">global</span> <span class="n">_loggerClass</span>
  1140. <span class="n">_loggerClass</span> <span class="o">=</span> <span class="n">klass</span>
  1141. <span class="k">def</span> <span class="nf">getLoggerClass</span><span class="p">():</span>
  1142. <span class="sd">&quot;&quot;&quot;</span>
  1143. <span class="sd"> Return the class to be used when instantiating a logger.</span>
  1144. <span class="sd"> &quot;&quot;&quot;</span>
  1145. <span class="k">return</span> <span class="n">_loggerClass</span>
  1146. <span class="k">class</span> <span class="nc">Manager</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
  1147. <span class="sd">&quot;&quot;&quot;</span>
  1148. <span class="sd"> There is [under normal circumstances] just one Manager instance, which</span>
  1149. <span class="sd"> holds the hierarchy of loggers.</span>
  1150. <span class="sd"> &quot;&quot;&quot;</span>
  1151. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">rootnode</span><span class="p">):</span>
  1152. <span class="sd">&quot;&quot;&quot;</span>
  1153. <span class="sd"> Initialize the manager with the root node of the logger hierarchy.</span>
  1154. <span class="sd"> &quot;&quot;&quot;</span>
  1155. <span class="bp">self</span><span class="o">.</span><span class="n">root</span> <span class="o">=</span> <span class="n">rootnode</span>
  1156. <span class="bp">self</span><span class="o">.</span><span class="n">disable</span> <span class="o">=</span> <span class="mi">0</span>
  1157. <span class="bp">self</span><span class="o">.</span><span class="n">emittedNoHandlerWarning</span> <span class="o">=</span> <span class="kc">False</span>
  1158. <span class="bp">self</span><span class="o">.</span><span class="n">loggerDict</span> <span class="o">=</span> <span class="p">{}</span>
  1159. <span class="bp">self</span><span class="o">.</span><span class="n">loggerClass</span> <span class="o">=</span> <span class="kc">None</span>
  1160. <span class="bp">self</span><span class="o">.</span><span class="n">logRecordFactory</span> <span class="o">=</span> <span class="kc">None</span>
  1161. <span class="nd">@property</span>
  1162. <span class="k">def</span> <span class="nf">disable</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  1163. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_disable</span>
  1164. <span class="nd">@disable</span><span class="o">.</span><span class="n">setter</span>
  1165. <span class="k">def</span> <span class="nf">disable</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">):</span>
  1166. <span class="bp">self</span><span class="o">.</span><span class="n">_disable</span> <span class="o">=</span> <span class="n">_checkLevel</span><span class="p">(</span><span class="n">value</span><span class="p">)</span>
  1167. <span class="k">def</span> <span class="nf">getLogger</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
  1168. <span class="sd">&quot;&quot;&quot;</span>
  1169. <span class="sd"> Get a logger with the specified name (channel name), creating it</span>
  1170. <span class="sd"> if it doesn&#39;t yet exist. This name is a dot-separated hierarchical</span>
  1171. <span class="sd"> name, such as &quot;a&quot;, &quot;a.b&quot;, &quot;a.b.c&quot; or similar.</span>
  1172. <span class="sd"> If a PlaceHolder existed for the specified name [i.e. the logger</span>
  1173. <span class="sd"> didn&#39;t exist but a child of it did], replace it with the created</span>
  1174. <span class="sd"> logger and fix up the parent/child references which pointed to the</span>
  1175. <span class="sd"> placeholder to now point to the logger.</span>
  1176. <span class="sd"> &quot;&quot;&quot;</span>
  1177. <span class="n">rv</span> <span class="o">=</span> <span class="kc">None</span>
  1178. <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
  1179. <span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s1">&#39;A logger name must be a string&#39;</span><span class="p">)</span>
  1180. <span class="n">_acquireLock</span><span class="p">()</span>
  1181. <span class="k">try</span><span class="p">:</span>
  1182. <span class="k">if</span> <span class="n">name</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">loggerDict</span><span class="p">:</span>
  1183. <span class="n">rv</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loggerDict</span><span class="p">[</span><span class="n">name</span><span class="p">]</span>
  1184. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">rv</span><span class="p">,</span> <span class="n">PlaceHolder</span><span class="p">):</span>
  1185. <span class="n">ph</span> <span class="o">=</span> <span class="n">rv</span>
  1186. <span class="n">rv</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">loggerClass</span> <span class="ow">or</span> <span class="n">_loggerClass</span><span class="p">)(</span><span class="n">name</span><span class="p">)</span>
  1187. <span class="n">rv</span><span class="o">.</span><span class="n">manager</span> <span class="o">=</span> <span class="bp">self</span>
  1188. <span class="bp">self</span><span class="o">.</span><span class="n">loggerDict</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">rv</span>
  1189. <span class="bp">self</span><span class="o">.</span><span class="n">_fixupChildren</span><span class="p">(</span><span class="n">ph</span><span class="p">,</span> <span class="n">rv</span><span class="p">)</span>
  1190. <span class="bp">self</span><span class="o">.</span><span class="n">_fixupParents</span><span class="p">(</span><span class="n">rv</span><span class="p">)</span>
  1191. <span class="k">else</span><span class="p">:</span>
  1192. <span class="n">rv</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">loggerClass</span> <span class="ow">or</span> <span class="n">_loggerClass</span><span class="p">)(</span><span class="n">name</span><span class="p">)</span>
  1193. <span class="n">rv</span><span class="o">.</span><span class="n">manager</span> <span class="o">=</span> <span class="bp">self</span>
  1194. <span class="bp">self</span><span class="o">.</span><span class="n">loggerDict</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">rv</span>
  1195. <span class="bp">self</span><span class="o">.</span><span class="n">_fixupParents</span><span class="p">(</span><span class="n">rv</span><span class="p">)</span>
  1196. <span class="k">finally</span><span class="p">:</span>
  1197. <span class="n">_releaseLock</span><span class="p">()</span>
  1198. <span class="k">return</span> <span class="n">rv</span>
  1199. <span class="k">def</span> <span class="nf">setLoggerClass</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">klass</span><span class="p">):</span>
  1200. <span class="sd">&quot;&quot;&quot;</span>
  1201. <span class="sd"> Set the class to be used when instantiating a logger with this Manager.</span>
  1202. <span class="sd"> &quot;&quot;&quot;</span>
  1203. <span class="k">if</span> <span class="n">klass</span> <span class="o">!=</span> <span class="n">Logger</span><span class="p">:</span>
  1204. <span class="k">if</span> <span class="ow">not</span> <span class="nb">issubclass</span><span class="p">(</span><span class="n">klass</span><span class="p">,</span> <span class="n">Logger</span><span class="p">):</span>
  1205. <span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">&quot;logger not derived from logging.Logger: &quot;</span>
  1206. <span class="o">+</span> <span class="n">klass</span><span class="o">.</span><span class="vm">__name__</span><span class="p">)</span>
  1207. <span class="bp">self</span><span class="o">.</span><span class="n">loggerClass</span> <span class="o">=</span> <span class="n">klass</span>
  1208. <span class="k">def</span> <span class="nf">setLogRecordFactory</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">factory</span><span class="p">):</span>
  1209. <span class="sd">&quot;&quot;&quot;</span>
  1210. <span class="sd"> Set the factory to be used when instantiating a log record with this</span>
  1211. <span class="sd"> Manager.</span>
  1212. <span class="sd"> &quot;&quot;&quot;</span>
  1213. <span class="bp">self</span><span class="o">.</span><span class="n">logRecordFactory</span> <span class="o">=</span> <span class="n">factory</span>
  1214. <span class="k">def</span> <span class="nf">_fixupParents</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">alogger</span><span class="p">):</span>
  1215. <span class="sd">&quot;&quot;&quot;</span>
  1216. <span class="sd"> Ensure that there are either loggers or placeholders all the way</span>
  1217. <span class="sd"> from the specified logger to the root of the logger hierarchy.</span>
  1218. <span class="sd"> &quot;&quot;&quot;</span>
  1219. <span class="n">name</span> <span class="o">=</span> <span class="n">alogger</span><span class="o">.</span><span class="n">name</span>
  1220. <span class="n">i</span> <span class="o">=</span> <span class="n">name</span><span class="o">.</span><span class="n">rfind</span><span class="p">(</span><span class="s2">&quot;.&quot;</span><span class="p">)</span>
  1221. <span class="n">rv</span> <span class="o">=</span> <span class="kc">None</span>
  1222. <span class="k">while</span> <span class="p">(</span><span class="n">i</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">)</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">rv</span><span class="p">:</span>
  1223. <span class="n">substr</span> <span class="o">=</span> <span class="n">name</span><span class="p">[:</span><span class="n">i</span><span class="p">]</span>
  1224. <span class="k">if</span> <span class="n">substr</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">loggerDict</span><span class="p">:</span>
  1225. <span class="bp">self</span><span class="o">.</span><span class="n">loggerDict</span><span class="p">[</span><span class="n">substr</span><span class="p">]</span> <span class="o">=</span> <span class="n">PlaceHolder</span><span class="p">(</span><span class="n">alogger</span><span class="p">)</span>
  1226. <span class="k">else</span><span class="p">:</span>
  1227. <span class="n">obj</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loggerDict</span><span class="p">[</span><span class="n">substr</span><span class="p">]</span>
  1228. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="n">Logger</span><span class="p">):</span>
  1229. <span class="n">rv</span> <span class="o">=</span> <span class="n">obj</span>
  1230. <span class="k">else</span><span class="p">:</span>
  1231. <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="n">PlaceHolder</span><span class="p">)</span>
  1232. <span class="n">obj</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">alogger</span><span class="p">)</span>
  1233. <span class="n">i</span> <span class="o">=</span> <span class="n">name</span><span class="o">.</span><span class="n">rfind</span><span class="p">(</span><span class="s2">&quot;.&quot;</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">i</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
  1234. <span class="k">if</span> <span class="ow">not</span> <span class="n">rv</span><span class="p">:</span>
  1235. <span class="n">rv</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">root</span>
  1236. <span class="n">alogger</span><span class="o">.</span><span class="n">parent</span> <span class="o">=</span> <span class="n">rv</span>
  1237. <span class="k">def</span> <span class="nf">_fixupChildren</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">ph</span><span class="p">,</span> <span class="n">alogger</span><span class="p">):</span>
  1238. <span class="sd">&quot;&quot;&quot;</span>
  1239. <span class="sd"> Ensure that children of the placeholder ph are connected to the</span>
  1240. <span class="sd"> specified logger.</span>
  1241. <span class="sd"> &quot;&quot;&quot;</span>
  1242. <span class="n">name</span> <span class="o">=</span> <span class="n">alogger</span><span class="o">.</span><span class="n">name</span>
  1243. <span class="n">namelen</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
  1244. <span class="k">for</span> <span class="n">c</span> <span class="ow">in</span> <span class="n">ph</span><span class="o">.</span><span class="n">loggerMap</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
  1245. <span class="c1">#The if means ... if not c.parent.name.startswith(nm)</span>
  1246. <span class="k">if</span> <span class="n">c</span><span class="o">.</span><span class="n">parent</span><span class="o">.</span><span class="n">name</span><span class="p">[:</span><span class="n">namelen</span><span class="p">]</span> <span class="o">!=</span> <span class="n">name</span><span class="p">:</span>
  1247. <span class="n">alogger</span><span class="o">.</span><span class="n">parent</span> <span class="o">=</span> <span class="n">c</span><span class="o">.</span><span class="n">parent</span>
  1248. <span class="n">c</span><span class="o">.</span><span class="n">parent</span> <span class="o">=</span> <span class="n">alogger</span>
  1249. <span class="k">def</span> <span class="nf">_clear_cache</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  1250. <span class="sd">&quot;&quot;&quot;</span>
  1251. <span class="sd"> Clear the cache for all loggers in loggerDict</span>
  1252. <span class="sd"> Called when level changes are made</span>
  1253. <span class="sd"> &quot;&quot;&quot;</span>
  1254. <span class="n">_acquireLock</span><span class="p">()</span>
  1255. <span class="k">for</span> <span class="n">logger</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">loggerDict</span><span class="o">.</span><span class="n">values</span><span class="p">():</span>
  1256. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">logger</span><span class="p">,</span> <span class="n">Logger</span><span class="p">):</span>
  1257. <span class="n">logger</span><span class="o">.</span><span class="n">_cache</span><span class="o">.</span><span class="n">clear</span><span class="p">()</span>
  1258. <span class="bp">self</span><span class="o">.</span><span class="n">root</span><span class="o">.</span><span class="n">_cache</span><span class="o">.</span><span class="n">clear</span><span class="p">()</span>
  1259. <span class="n">_releaseLock</span><span class="p">()</span>
  1260. <span class="c1">#---------------------------------------------------------------------------</span>
  1261. <span class="c1"># Logger classes and functions</span>
  1262. <span class="c1">#---------------------------------------------------------------------------</span>
  1263. <span class="k">class</span> <span class="nc">Logger</span><span class="p">(</span><span class="n">Filterer</span><span class="p">):</span>
  1264. <span class="sd">&quot;&quot;&quot;</span>
  1265. <span class="sd"> Instances of the Logger class represent a single logging channel. A</span>
  1266. <span class="sd"> &quot;logging channel&quot; indicates an area of an application. Exactly how an</span>
  1267. <span class="sd"> &quot;area&quot; is defined is up to the application developer. Since an</span>
  1268. <span class="sd"> application can have any number of areas, logging channels are identified</span>
  1269. <span class="sd"> by a unique string. Application areas can be nested (e.g. an area</span>
  1270. <span class="sd"> of &quot;input processing&quot; might include sub-areas &quot;read CSV files&quot;, &quot;read</span>
  1271. <span class="sd"> XLS files&quot; and &quot;read Gnumeric files&quot;). To cater for this natural nesting,</span>
  1272. <span class="sd"> channel names are organized into a namespace hierarchy where levels are</span>
  1273. <span class="sd"> separated by periods, much like the Java or Python package namespace. So</span>
  1274. <span class="sd"> in the instance given above, channel names might be &quot;input&quot; for the upper</span>
  1275. <span class="sd"> level, and &quot;input.csv&quot;, &quot;input.xls&quot; and &quot;input.gnu&quot; for the sub-levels.</span>
  1276. <span class="sd"> There is no arbitrary limit to the depth of nesting.</span>
  1277. <span class="sd"> &quot;&quot;&quot;</span>
  1278. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">level</span><span class="o">=</span><span class="n">NOTSET</span><span class="p">):</span>
  1279. <span class="sd">&quot;&quot;&quot;</span>
  1280. <span class="sd"> Initialize the logger with a name and an optional level.</span>
  1281. <span class="sd"> &quot;&quot;&quot;</span>
  1282. <span class="n">Filterer</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span>
  1283. <span class="bp">self</span><span class="o">.</span><span class="n">name</span> <span class="o">=</span> <span class="n">name</span>
  1284. <span class="bp">self</span><span class="o">.</span><span class="n">level</span> <span class="o">=</span> <span class="n">_checkLevel</span><span class="p">(</span><span class="n">level</span><span class="p">)</span>
  1285. <span class="bp">self</span><span class="o">.</span><span class="n">parent</span> <span class="o">=</span> <span class="kc">None</span>
  1286. <span class="bp">self</span><span class="o">.</span><span class="n">propagate</span> <span class="o">=</span> <span class="kc">True</span>
  1287. <span class="bp">self</span><span class="o">.</span><span class="n">handlers</span> <span class="o">=</span> <span class="p">[]</span>
  1288. <span class="bp">self</span><span class="o">.</span><span class="n">disabled</span> <span class="o">=</span> <span class="kc">False</span>
  1289. <span class="bp">self</span><span class="o">.</span><span class="n">_cache</span> <span class="o">=</span> <span class="p">{}</span>
  1290. <span class="k">def</span> <span class="nf">setLevel</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">level</span><span class="p">):</span>
  1291. <span class="sd">&quot;&quot;&quot;</span>
  1292. <span class="sd"> Set the logging level of this logger. level must be an int or a str.</span>
  1293. <span class="sd"> &quot;&quot;&quot;</span>
  1294. <span class="bp">self</span><span class="o">.</span><span class="n">level</span> <span class="o">=</span> <span class="n">_checkLevel</span><span class="p">(</span><span class="n">level</span><span class="p">)</span>
  1295. <span class="bp">self</span><span class="o">.</span><span class="n">manager</span><span class="o">.</span><span class="n">_clear_cache</span><span class="p">()</span>
  1296. <span class="k">def</span> <span class="nf">debug</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  1297. <span class="sd">&quot;&quot;&quot;</span>
  1298. <span class="sd"> Log &#39;msg % args&#39; with severity &#39;DEBUG&#39;.</span>
  1299. <span class="sd"> To pass exception information, use the keyword argument exc_info with</span>
  1300. <span class="sd"> a true value, e.g.</span>
  1301. <span class="sd"> logger.debug(&quot;Houston, we have a %s&quot;, &quot;thorny problem&quot;, exc_info=1)</span>
  1302. <span class="sd"> &quot;&quot;&quot;</span>
  1303. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">isEnabledFor</span><span class="p">(</span><span class="n">DEBUG</span><span class="p">):</span>
  1304. <span class="bp">self</span><span class="o">.</span><span class="n">_log</span><span class="p">(</span><span class="n">DEBUG</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
  1305. <span class="k">def</span> <span class="nf">info</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  1306. <span class="sd">&quot;&quot;&quot;</span>
  1307. <span class="sd"> Log &#39;msg % args&#39; with severity &#39;INFO&#39;.</span>
  1308. <span class="sd"> To pass exception information, use the keyword argument exc_info with</span>
  1309. <span class="sd"> a true value, e.g.</span>
  1310. <span class="sd"> logger.info(&quot;Houston, we have a %s&quot;, &quot;interesting problem&quot;, exc_info=1)</span>
  1311. <span class="sd"> &quot;&quot;&quot;</span>
  1312. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">isEnabledFor</span><span class="p">(</span><span class="n">INFO</span><span class="p">):</span>
  1313. <span class="bp">self</span><span class="o">.</span><span class="n">_log</span><span class="p">(</span><span class="n">INFO</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
  1314. <span class="k">def</span> <span class="nf">warning</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  1315. <span class="sd">&quot;&quot;&quot;</span>
  1316. <span class="sd"> Log &#39;msg % args&#39; with severity &#39;WARNING&#39;.</span>
  1317. <span class="sd"> To pass exception information, use the keyword argument exc_info with</span>
  1318. <span class="sd"> a true value, e.g.</span>
  1319. <span class="sd"> logger.warning(&quot;Houston, we have a %s&quot;, &quot;bit of a problem&quot;, exc_info=1)</span>
  1320. <span class="sd"> &quot;&quot;&quot;</span>
  1321. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">isEnabledFor</span><span class="p">(</span><span class="n">WARNING</span><span class="p">):</span>
  1322. <span class="bp">self</span><span class="o">.</span><span class="n">_log</span><span class="p">(</span><span class="n">WARNING</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
  1323. <span class="k">def</span> <span class="nf">warn</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  1324. <span class="n">warnings</span><span class="o">.</span><span class="n">warn</span><span class="p">(</span><span class="s2">&quot;The &#39;warn&#39; method is deprecated, &quot;</span>
  1325. <span class="s2">&quot;use &#39;warning&#39; instead&quot;</span><span class="p">,</span> <span class="ne">DeprecationWarning</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
  1326. <span class="bp">self</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
  1327. <span class="k">def</span> <span class="nf">error</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  1328. <span class="sd">&quot;&quot;&quot;</span>
  1329. <span class="sd"> Log &#39;msg % args&#39; with severity &#39;ERROR&#39;.</span>
  1330. <span class="sd"> To pass exception information, use the keyword argument exc_info with</span>
  1331. <span class="sd"> a true value, e.g.</span>
  1332. <span class="sd"> logger.error(&quot;Houston, we have a %s&quot;, &quot;major problem&quot;, exc_info=1)</span>
  1333. <span class="sd"> &quot;&quot;&quot;</span>
  1334. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">isEnabledFor</span><span class="p">(</span><span class="n">ERROR</span><span class="p">):</span>
  1335. <span class="bp">self</span><span class="o">.</span><span class="n">_log</span><span class="p">(</span><span class="n">ERROR</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
  1336. <span class="k">def</span> <span class="nf">exception</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="n">exc_info</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  1337. <span class="sd">&quot;&quot;&quot;</span>
  1338. <span class="sd"> Convenience method for logging an ERROR with exception information.</span>
  1339. <span class="sd"> &quot;&quot;&quot;</span>
  1340. <span class="bp">self</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="n">exc_info</span><span class="o">=</span><span class="n">exc_info</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
  1341. <span class="k">def</span> <span class="nf">critical</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  1342. <span class="sd">&quot;&quot;&quot;</span>
  1343. <span class="sd"> Log &#39;msg % args&#39; with severity &#39;CRITICAL&#39;.</span>
  1344. <span class="sd"> To pass exception information, use the keyword argument exc_info with</span>
  1345. <span class="sd"> a true value, e.g.</span>
  1346. <span class="sd"> logger.critical(&quot;Houston, we have a %s&quot;, &quot;major disaster&quot;, exc_info=1)</span>
  1347. <span class="sd"> &quot;&quot;&quot;</span>
  1348. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">isEnabledFor</span><span class="p">(</span><span class="n">CRITICAL</span><span class="p">):</span>
  1349. <span class="bp">self</span><span class="o">.</span><span class="n">_log</span><span class="p">(</span><span class="n">CRITICAL</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
  1350. <span class="n">fatal</span> <span class="o">=</span> <span class="n">critical</span>
  1351. <span class="k">def</span> <span class="nf">log</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">level</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  1352. <span class="sd">&quot;&quot;&quot;</span>
  1353. <span class="sd"> Log &#39;msg % args&#39; with the integer severity &#39;level&#39;.</span>
  1354. <span class="sd"> To pass exception information, use the keyword argument exc_info with</span>
  1355. <span class="sd"> a true value, e.g.</span>
  1356. <span class="sd"> logger.log(level, &quot;We have a %s&quot;, &quot;mysterious problem&quot;, exc_info=1)</span>
  1357. <span class="sd"> &quot;&quot;&quot;</span>
  1358. <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span>
  1359. <span class="k">if</span> <span class="n">raiseExceptions</span><span class="p">:</span>
  1360. <span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">&quot;level must be an integer&quot;</span><span class="p">)</span>
  1361. <span class="k">else</span><span class="p">:</span>
  1362. <span class="k">return</span>
  1363. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">isEnabledFor</span><span class="p">(</span><span class="n">level</span><span class="p">):</span>
  1364. <span class="bp">self</span><span class="o">.</span><span class="n">_log</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
  1365. <span class="k">def</span> <span class="nf">findCaller</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">stack_info</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">stacklevel</span><span class="o">=</span><span class="mi">1</span><span class="p">):</span>
  1366. <span class="sd">&quot;&quot;&quot;</span>
  1367. <span class="sd"> Find the stack frame of the caller so that we can note the source</span>
  1368. <span class="sd"> file name, line number and function name.</span>
  1369. <span class="sd"> &quot;&quot;&quot;</span>
  1370. <span class="n">f</span> <span class="o">=</span> <span class="n">currentframe</span><span class="p">()</span>
  1371. <span class="c1">#On some versions of IronPython, currentframe() returns None if</span>
  1372. <span class="c1">#IronPython isn&#39;t run with -X:Frames.</span>
  1373. <span class="k">if</span> <span class="n">f</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  1374. <span class="n">f</span> <span class="o">=</span> <span class="n">f</span><span class="o">.</span><span class="n">f_back</span>
  1375. <span class="n">orig_f</span> <span class="o">=</span> <span class="n">f</span>
  1376. <span class="k">while</span> <span class="n">f</span> <span class="ow">and</span> <span class="n">stacklevel</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
  1377. <span class="n">f</span> <span class="o">=</span> <span class="n">f</span><span class="o">.</span><span class="n">f_back</span>
  1378. <span class="n">stacklevel</span> <span class="o">-=</span> <span class="mi">1</span>
  1379. <span class="k">if</span> <span class="ow">not</span> <span class="n">f</span><span class="p">:</span>
  1380. <span class="n">f</span> <span class="o">=</span> <span class="n">orig_f</span>
  1381. <span class="n">rv</span> <span class="o">=</span> <span class="s2">&quot;(unknown file)&quot;</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="s2">&quot;(unknown function)&quot;</span><span class="p">,</span> <span class="kc">None</span>
  1382. <span class="k">while</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">f</span><span class="p">,</span> <span class="s2">&quot;f_code&quot;</span><span class="p">):</span>
  1383. <span class="n">co</span> <span class="o">=</span> <span class="n">f</span><span class="o">.</span><span class="n">f_code</span>
  1384. <span class="n">filename</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">normcase</span><span class="p">(</span><span class="n">co</span><span class="o">.</span><span class="n">co_filename</span><span class="p">)</span>
  1385. <span class="k">if</span> <span class="n">filename</span> <span class="o">==</span> <span class="n">_srcfile</span><span class="p">:</span>
  1386. <span class="n">f</span> <span class="o">=</span> <span class="n">f</span><span class="o">.</span><span class="n">f_back</span>
  1387. <span class="k">continue</span>
  1388. <span class="n">sinfo</span> <span class="o">=</span> <span class="kc">None</span>
  1389. <span class="k">if</span> <span class="n">stack_info</span><span class="p">:</span>
  1390. <span class="n">sio</span> <span class="o">=</span> <span class="n">io</span><span class="o">.</span><span class="n">StringIO</span><span class="p">()</span>
  1391. <span class="n">sio</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="s1">&#39;Stack (most recent call last):</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">)</span>
  1392. <span class="n">traceback</span><span class="o">.</span><span class="n">print_stack</span><span class="p">(</span><span class="n">f</span><span class="p">,</span> <span class="n">file</span><span class="o">=</span><span class="n">sio</span><span class="p">)</span>
  1393. <span class="n">sinfo</span> <span class="o">=</span> <span class="n">sio</span><span class="o">.</span><span class="n">getvalue</span><span class="p">()</span>
  1394. <span class="k">if</span> <span class="n">sinfo</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="s1">&#39;</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">:</span>
  1395. <span class="n">sinfo</span> <span class="o">=</span> <span class="n">sinfo</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
  1396. <span class="n">sio</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>
  1397. <span class="n">rv</span> <span class="o">=</span> <span class="p">(</span><span class="n">co</span><span class="o">.</span><span class="n">co_filename</span><span class="p">,</span> <span class="n">f</span><span class="o">.</span><span class="n">f_lineno</span><span class="p">,</span> <span class="n">co</span><span class="o">.</span><span class="n">co_name</span><span class="p">,</span> <span class="n">sinfo</span><span class="p">)</span>
  1398. <span class="k">break</span>
  1399. <span class="k">return</span> <span class="n">rv</span>
  1400. <span class="k">def</span> <span class="nf">makeRecord</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">level</span><span class="p">,</span> <span class="n">fn</span><span class="p">,</span> <span class="n">lno</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="n">args</span><span class="p">,</span> <span class="n">exc_info</span><span class="p">,</span>
  1401. <span class="n">func</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">extra</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">sinfo</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  1402. <span class="sd">&quot;&quot;&quot;</span>
  1403. <span class="sd"> A factory method which can be overridden in subclasses to create</span>
  1404. <span class="sd"> specialized LogRecords.</span>
  1405. <span class="sd"> &quot;&quot;&quot;</span>
  1406. <span class="n">rv</span> <span class="o">=</span> <span class="n">_logRecordFactory</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">level</span><span class="p">,</span> <span class="n">fn</span><span class="p">,</span> <span class="n">lno</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="n">args</span><span class="p">,</span> <span class="n">exc_info</span><span class="p">,</span> <span class="n">func</span><span class="p">,</span>
  1407. <span class="n">sinfo</span><span class="p">)</span>
  1408. <span class="k">if</span> <span class="n">extra</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  1409. <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">extra</span><span class="p">:</span>
  1410. <span class="k">if</span> <span class="p">(</span><span class="n">key</span> <span class="ow">in</span> <span class="p">[</span><span class="s2">&quot;message&quot;</span><span class="p">,</span> <span class="s2">&quot;asctime&quot;</span><span class="p">])</span> <span class="ow">or</span> <span class="p">(</span><span class="n">key</span> <span class="ow">in</span> <span class="n">rv</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">):</span>
  1411. <span class="k">raise</span> <span class="ne">KeyError</span><span class="p">(</span><span class="s2">&quot;Attempt to overwrite </span><span class="si">%r</span><span class="s2"> in LogRecord&quot;</span> <span class="o">%</span> <span class="n">key</span><span class="p">)</span>
  1412. <span class="n">rv</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">extra</span><span class="p">[</span><span class="n">key</span><span class="p">]</span>
  1413. <span class="k">return</span> <span class="n">rv</span>
  1414. <span class="k">def</span> <span class="nf">_log</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">level</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="n">args</span><span class="p">,</span> <span class="n">exc_info</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">extra</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">stack_info</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
  1415. <span class="n">stacklevel</span><span class="o">=</span><span class="mi">1</span><span class="p">):</span>
  1416. <span class="sd">&quot;&quot;&quot;</span>
  1417. <span class="sd"> Low-level logging routine which creates a LogRecord and then calls</span>
  1418. <span class="sd"> all the handlers of this logger to handle the record.</span>
  1419. <span class="sd"> &quot;&quot;&quot;</span>
  1420. <span class="n">sinfo</span> <span class="o">=</span> <span class="kc">None</span>
  1421. <span class="k">if</span> <span class="n">_srcfile</span><span class="p">:</span>
  1422. <span class="c1">#IronPython doesn&#39;t track Python frames, so findCaller raises an</span>
  1423. <span class="c1">#exception on some versions of IronPython. We trap it here so that</span>
  1424. <span class="c1">#IronPython can use logging.</span>
  1425. <span class="k">try</span><span class="p">:</span>
  1426. <span class="n">fn</span><span class="p">,</span> <span class="n">lno</span><span class="p">,</span> <span class="n">func</span><span class="p">,</span> <span class="n">sinfo</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">findCaller</span><span class="p">(</span><span class="n">stack_info</span><span class="p">,</span> <span class="n">stacklevel</span><span class="p">)</span>
  1427. <span class="k">except</span> <span class="ne">ValueError</span><span class="p">:</span> <span class="c1"># pragma: no cover</span>
  1428. <span class="n">fn</span><span class="p">,</span> <span class="n">lno</span><span class="p">,</span> <span class="n">func</span> <span class="o">=</span> <span class="s2">&quot;(unknown file)&quot;</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="s2">&quot;(unknown function)&quot;</span>
  1429. <span class="k">else</span><span class="p">:</span> <span class="c1"># pragma: no cover</span>
  1430. <span class="n">fn</span><span class="p">,</span> <span class="n">lno</span><span class="p">,</span> <span class="n">func</span> <span class="o">=</span> <span class="s2">&quot;(unknown file)&quot;</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="s2">&quot;(unknown function)&quot;</span>
  1431. <span class="k">if</span> <span class="n">exc_info</span><span class="p">:</span>
  1432. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">exc_info</span><span class="p">,</span> <span class="ne">BaseException</span><span class="p">):</span>
  1433. <span class="n">exc_info</span> <span class="o">=</span> <span class="p">(</span><span class="nb">type</span><span class="p">(</span><span class="n">exc_info</span><span class="p">),</span> <span class="n">exc_info</span><span class="p">,</span> <span class="n">exc_info</span><span class="o">.</span><span class="n">__traceback__</span><span class="p">)</span>
  1434. <span class="k">elif</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">exc_info</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">):</span>
  1435. <span class="n">exc_info</span> <span class="o">=</span> <span class="n">sys</span><span class="o">.</span><span class="n">exc_info</span><span class="p">()</span>
  1436. <span class="n">record</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">makeRecord</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">level</span><span class="p">,</span> <span class="n">fn</span><span class="p">,</span> <span class="n">lno</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="n">args</span><span class="p">,</span>
  1437. <span class="n">exc_info</span><span class="p">,</span> <span class="n">func</span><span class="p">,</span> <span class="n">extra</span><span class="p">,</span> <span class="n">sinfo</span><span class="p">)</span>
  1438. <span class="bp">self</span><span class="o">.</span><span class="n">handle</span><span class="p">(</span><span class="n">record</span><span class="p">)</span>
  1439. <span class="k">def</span> <span class="nf">handle</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">record</span><span class="p">):</span>
  1440. <span class="sd">&quot;&quot;&quot;</span>
  1441. <span class="sd"> Call the handlers for the specified record.</span>
  1442. <span class="sd"> This method is used for unpickled records received from a socket, as</span>
  1443. <span class="sd"> well as those created locally. Logger-level filtering is applied.</span>
  1444. <span class="sd"> &quot;&quot;&quot;</span>
  1445. <span class="k">if</span> <span class="p">(</span><span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">disabled</span><span class="p">)</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">record</span><span class="p">):</span>
  1446. <span class="bp">self</span><span class="o">.</span><span class="n">callHandlers</span><span class="p">(</span><span class="n">record</span><span class="p">)</span>
  1447. <span class="k">def</span> <span class="nf">addHandler</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">hdlr</span><span class="p">):</span>
  1448. <span class="sd">&quot;&quot;&quot;</span>
  1449. <span class="sd"> Add the specified handler to this logger.</span>
  1450. <span class="sd"> &quot;&quot;&quot;</span>
  1451. <span class="n">_acquireLock</span><span class="p">()</span>
  1452. <span class="k">try</span><span class="p">:</span>
  1453. <span class="k">if</span> <span class="ow">not</span> <span class="p">(</span><span class="n">hdlr</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">handlers</span><span class="p">):</span>
  1454. <span class="bp">self</span><span class="o">.</span><span class="n">handlers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">hdlr</span><span class="p">)</span>
  1455. <span class="k">finally</span><span class="p">:</span>
  1456. <span class="n">_releaseLock</span><span class="p">()</span>
  1457. <span class="k">def</span> <span class="nf">removeHandler</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">hdlr</span><span class="p">):</span>
  1458. <span class="sd">&quot;&quot;&quot;</span>
  1459. <span class="sd"> Remove the specified handler from this logger.</span>
  1460. <span class="sd"> &quot;&quot;&quot;</span>
  1461. <span class="n">_acquireLock</span><span class="p">()</span>
  1462. <span class="k">try</span><span class="p">:</span>
  1463. <span class="k">if</span> <span class="n">hdlr</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">handlers</span><span class="p">:</span>
  1464. <span class="bp">self</span><span class="o">.</span><span class="n">handlers</span><span class="o">.</span><span class="n">remove</span><span class="p">(</span><span class="n">hdlr</span><span class="p">)</span>
  1465. <span class="k">finally</span><span class="p">:</span>
  1466. <span class="n">_releaseLock</span><span class="p">()</span>
  1467. <span class="k">def</span> <span class="nf">hasHandlers</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  1468. <span class="sd">&quot;&quot;&quot;</span>
  1469. <span class="sd"> See if this logger has any handlers configured.</span>
  1470. <span class="sd"> Loop through all handlers for this logger and its parents in the</span>
  1471. <span class="sd"> logger hierarchy. Return True if a handler was found, else False.</span>
  1472. <span class="sd"> Stop searching up the hierarchy whenever a logger with the &quot;propagate&quot;</span>
  1473. <span class="sd"> attribute set to zero is found - that will be the last logger which</span>
  1474. <span class="sd"> is checked for the existence of handlers.</span>
  1475. <span class="sd"> &quot;&quot;&quot;</span>
  1476. <span class="n">c</span> <span class="o">=</span> <span class="bp">self</span>
  1477. <span class="n">rv</span> <span class="o">=</span> <span class="kc">False</span>
  1478. <span class="k">while</span> <span class="n">c</span><span class="p">:</span>
  1479. <span class="k">if</span> <span class="n">c</span><span class="o">.</span><span class="n">handlers</span><span class="p">:</span>
  1480. <span class="n">rv</span> <span class="o">=</span> <span class="kc">True</span>
  1481. <span class="k">break</span>
  1482. <span class="k">if</span> <span class="ow">not</span> <span class="n">c</span><span class="o">.</span><span class="n">propagate</span><span class="p">:</span>
  1483. <span class="k">break</span>
  1484. <span class="k">else</span><span class="p">:</span>
  1485. <span class="n">c</span> <span class="o">=</span> <span class="n">c</span><span class="o">.</span><span class="n">parent</span>
  1486. <span class="k">return</span> <span class="n">rv</span>
  1487. <span class="k">def</span> <span class="nf">callHandlers</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">record</span><span class="p">):</span>
  1488. <span class="sd">&quot;&quot;&quot;</span>
  1489. <span class="sd"> Pass a record to all relevant handlers.</span>
  1490. <span class="sd"> Loop through all handlers for this logger and its parents in the</span>
  1491. <span class="sd"> logger hierarchy. If no handler was found, output a one-off error</span>
  1492. <span class="sd"> message to sys.stderr. Stop searching up the hierarchy whenever a</span>
  1493. <span class="sd"> logger with the &quot;propagate&quot; attribute set to zero is found - that</span>
  1494. <span class="sd"> will be the last logger whose handlers are called.</span>
  1495. <span class="sd"> &quot;&quot;&quot;</span>
  1496. <span class="n">c</span> <span class="o">=</span> <span class="bp">self</span>
  1497. <span class="n">found</span> <span class="o">=</span> <span class="mi">0</span>
  1498. <span class="k">while</span> <span class="n">c</span><span class="p">:</span>
  1499. <span class="k">for</span> <span class="n">hdlr</span> <span class="ow">in</span> <span class="n">c</span><span class="o">.</span><span class="n">handlers</span><span class="p">:</span>
  1500. <span class="n">found</span> <span class="o">=</span> <span class="n">found</span> <span class="o">+</span> <span class="mi">1</span>
  1501. <span class="k">if</span> <span class="n">record</span><span class="o">.</span><span class="n">levelno</span> <span class="o">&gt;=</span> <span class="n">hdlr</span><span class="o">.</span><span class="n">level</span><span class="p">:</span>
  1502. <span class="n">hdlr</span><span class="o">.</span><span class="n">handle</span><span class="p">(</span><span class="n">record</span><span class="p">)</span>
  1503. <span class="k">if</span> <span class="ow">not</span> <span class="n">c</span><span class="o">.</span><span class="n">propagate</span><span class="p">:</span>
  1504. <span class="n">c</span> <span class="o">=</span> <span class="kc">None</span> <span class="c1">#break out</span>
  1505. <span class="k">else</span><span class="p">:</span>
  1506. <span class="n">c</span> <span class="o">=</span> <span class="n">c</span><span class="o">.</span><span class="n">parent</span>
  1507. <span class="k">if</span> <span class="p">(</span><span class="n">found</span> <span class="o">==</span> <span class="mi">0</span><span class="p">):</span>
  1508. <span class="k">if</span> <span class="n">lastResort</span><span class="p">:</span>
  1509. <span class="k">if</span> <span class="n">record</span><span class="o">.</span><span class="n">levelno</span> <span class="o">&gt;=</span> <span class="n">lastResort</span><span class="o">.</span><span class="n">level</span><span class="p">:</span>
  1510. <span class="n">lastResort</span><span class="o">.</span><span class="n">handle</span><span class="p">(</span><span class="n">record</span><span class="p">)</span>
  1511. <span class="k">elif</span> <span class="n">raiseExceptions</span> <span class="ow">and</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">manager</span><span class="o">.</span><span class="n">emittedNoHandlerWarning</span><span class="p">:</span>
  1512. <span class="n">sys</span><span class="o">.</span><span class="n">stderr</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="s2">&quot;No handlers could be found for logger&quot;</span>
  1513. <span class="s2">&quot; </span><span class="se">\&quot;</span><span class="si">%s</span><span class="se">\&quot;\n</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">)</span>
  1514. <span class="bp">self</span><span class="o">.</span><span class="n">manager</span><span class="o">.</span><span class="n">emittedNoHandlerWarning</span> <span class="o">=</span> <span class="kc">True</span>
  1515. <span class="k">def</span> <span class="nf">getEffectiveLevel</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  1516. <span class="sd">&quot;&quot;&quot;</span>
  1517. <span class="sd"> Get the effective level for this logger.</span>
  1518. <span class="sd"> Loop through this logger and its parents in the logger hierarchy,</span>
  1519. <span class="sd"> looking for a non-zero logging level. Return the first one found.</span>
  1520. <span class="sd"> &quot;&quot;&quot;</span>
  1521. <span class="n">logger</span> <span class="o">=</span> <span class="bp">self</span>
  1522. <span class="k">while</span> <span class="n">logger</span><span class="p">:</span>
  1523. <span class="k">if</span> <span class="n">logger</span><span class="o">.</span><span class="n">level</span><span class="p">:</span>
  1524. <span class="k">return</span> <span class="n">logger</span><span class="o">.</span><span class="n">level</span>
  1525. <span class="n">logger</span> <span class="o">=</span> <span class="n">logger</span><span class="o">.</span><span class="n">parent</span>
  1526. <span class="k">return</span> <span class="n">NOTSET</span>
  1527. <span class="k">def</span> <span class="nf">isEnabledFor</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">level</span><span class="p">):</span>
  1528. <span class="sd">&quot;&quot;&quot;</span>
  1529. <span class="sd"> Is this logger enabled for level &#39;level&#39;?</span>
  1530. <span class="sd"> &quot;&quot;&quot;</span>
  1531. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">disabled</span><span class="p">:</span>
  1532. <span class="k">return</span> <span class="kc">False</span>
  1533. <span class="k">try</span><span class="p">:</span>
  1534. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_cache</span><span class="p">[</span><span class="n">level</span><span class="p">]</span>
  1535. <span class="k">except</span> <span class="ne">KeyError</span><span class="p">:</span>
  1536. <span class="n">_acquireLock</span><span class="p">()</span>
  1537. <span class="k">try</span><span class="p">:</span>
  1538. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">manager</span><span class="o">.</span><span class="n">disable</span> <span class="o">&gt;=</span> <span class="n">level</span><span class="p">:</span>
  1539. <span class="n">is_enabled</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_cache</span><span class="p">[</span><span class="n">level</span><span class="p">]</span> <span class="o">=</span> <span class="kc">False</span>
  1540. <span class="k">else</span><span class="p">:</span>
  1541. <span class="n">is_enabled</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_cache</span><span class="p">[</span><span class="n">level</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span>
  1542. <span class="n">level</span> <span class="o">&gt;=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getEffectiveLevel</span><span class="p">()</span>
  1543. <span class="p">)</span>
  1544. <span class="k">finally</span><span class="p">:</span>
  1545. <span class="n">_releaseLock</span><span class="p">()</span>
  1546. <span class="k">return</span> <span class="n">is_enabled</span>
  1547. <span class="k">def</span> <span class="nf">getChild</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">suffix</span><span class="p">):</span>
  1548. <span class="sd">&quot;&quot;&quot;</span>
  1549. <span class="sd"> Get a logger which is a descendant to this one.</span>
  1550. <span class="sd"> This is a convenience method, such that</span>
  1551. <span class="sd"> logging.getLogger(&#39;abc&#39;).getChild(&#39;def.ghi&#39;)</span>
  1552. <span class="sd"> is the same as</span>
  1553. <span class="sd"> logging.getLogger(&#39;abc.def.ghi&#39;)</span>
  1554. <span class="sd"> It&#39;s useful, for example, when the parent logger is named using</span>
  1555. <span class="sd"> __name__ rather than a literal string.</span>
  1556. <span class="sd"> &quot;&quot;&quot;</span>
  1557. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">root</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">self</span><span class="p">:</span>
  1558. <span class="n">suffix</span> <span class="o">=</span> <span class="s1">&#39;.&#39;</span><span class="o">.</span><span class="n">join</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">suffix</span><span class="p">))</span>
  1559. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">manager</span><span class="o">.</span><span class="n">getLogger</span><span class="p">(</span><span class="n">suffix</span><span class="p">)</span>
  1560. <span class="k">def</span> <span class="fm">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  1561. <span class="n">level</span> <span class="o">=</span> <span class="n">getLevelName</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getEffectiveLevel</span><span class="p">())</span>
  1562. <span class="k">return</span> <span class="s1">&#39;&lt;</span><span class="si">%s</span><span class="s1"> </span><span class="si">%s</span><span class="s1"> (</span><span class="si">%s</span><span class="s1">)&gt;&#39;</span> <span class="o">%</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">level</span><span class="p">)</span>
  1563. <span class="k">def</span> <span class="nf">__reduce__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  1564. <span class="c1"># In general, only the root logger will not be accessible via its name.</span>
  1565. <span class="c1"># However, the root logger&#39;s class has its own __reduce__ method.</span>
  1566. <span class="k">if</span> <span class="n">getLogger</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">)</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">self</span><span class="p">:</span>
  1567. <span class="kn">import</span> <span class="nn">pickle</span>
  1568. <span class="k">raise</span> <span class="n">pickle</span><span class="o">.</span><span class="n">PicklingError</span><span class="p">(</span><span class="s1">&#39;logger cannot be pickled&#39;</span><span class="p">)</span>
  1569. <span class="k">return</span> <span class="n">getLogger</span><span class="p">,</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">,)</span>
  1570. <span class="k">class</span> <span class="nc">RootLogger</span><span class="p">(</span><span class="n">Logger</span><span class="p">):</span>
  1571. <span class="sd">&quot;&quot;&quot;</span>
  1572. <span class="sd"> A root logger is not that different to any other logger, except that</span>
  1573. <span class="sd"> it must have a logging level and there is only one instance of it in</span>
  1574. <span class="sd"> the hierarchy.</span>
  1575. <span class="sd"> &quot;&quot;&quot;</span>
  1576. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">level</span><span class="p">):</span>
  1577. <span class="sd">&quot;&quot;&quot;</span>
  1578. <span class="sd"> Initialize the logger with the name &quot;root&quot;.</span>
  1579. <span class="sd"> &quot;&quot;&quot;</span>
  1580. <span class="n">Logger</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="s2">&quot;root&quot;</span><span class="p">,</span> <span class="n">level</span><span class="p">)</span>
  1581. <span class="k">def</span> <span class="nf">__reduce__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  1582. <span class="k">return</span> <span class="n">getLogger</span><span class="p">,</span> <span class="p">()</span>
  1583. <span class="n">_loggerClass</span> <span class="o">=</span> <span class="n">Logger</span>
  1584. <span class="k">class</span> <span class="nc">LoggerAdapter</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
  1585. <span class="sd">&quot;&quot;&quot;</span>
  1586. <span class="sd"> An adapter for loggers which makes it easier to specify contextual</span>
  1587. <span class="sd"> information in logging output.</span>
  1588. <span class="sd"> &quot;&quot;&quot;</span>
  1589. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logger</span><span class="p">,</span> <span class="n">extra</span><span class="p">):</span>
  1590. <span class="sd">&quot;&quot;&quot;</span>
  1591. <span class="sd"> Initialize the adapter with a logger and a dict-like object which</span>
  1592. <span class="sd"> provides contextual information. This constructor signature allows</span>
  1593. <span class="sd"> easy stacking of LoggerAdapters, if so desired.</span>
  1594. <span class="sd"> You can effectively pass keyword arguments as shown in the</span>
  1595. <span class="sd"> following example:</span>
  1596. <span class="sd"> adapter = LoggerAdapter(someLogger, dict(p1=v1, p2=&quot;v2&quot;))</span>
  1597. <span class="sd"> &quot;&quot;&quot;</span>
  1598. <span class="bp">self</span><span class="o">.</span><span class="n">logger</span> <span class="o">=</span> <span class="n">logger</span>
  1599. <span class="bp">self</span><span class="o">.</span><span class="n">extra</span> <span class="o">=</span> <span class="n">extra</span>
  1600. <span class="k">def</span> <span class="nf">process</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="n">kwargs</span><span class="p">):</span>
  1601. <span class="sd">&quot;&quot;&quot;</span>
  1602. <span class="sd"> Process the logging message and keyword arguments passed in to</span>
  1603. <span class="sd"> a logging call to insert contextual information. You can either</span>
  1604. <span class="sd"> manipulate the message itself, the keyword args or both. Return</span>
  1605. <span class="sd"> the message and kwargs modified (or not) to suit your needs.</span>
  1606. <span class="sd"> Normally, you&#39;ll only need to override this one method in a</span>
  1607. <span class="sd"> LoggerAdapter subclass for your specific needs.</span>
  1608. <span class="sd"> &quot;&quot;&quot;</span>
  1609. <span class="n">kwargs</span><span class="p">[</span><span class="s2">&quot;extra&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">extra</span>
  1610. <span class="k">return</span> <span class="n">msg</span><span class="p">,</span> <span class="n">kwargs</span>
  1611. <span class="c1">#</span>
  1612. <span class="c1"># Boilerplate convenience methods</span>
  1613. <span class="c1">#</span>
  1614. <span class="k">def</span> <span class="nf">debug</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  1615. <span class="sd">&quot;&quot;&quot;</span>
  1616. <span class="sd"> Delegate a debug call to the underlying logger.</span>
  1617. <span class="sd"> &quot;&quot;&quot;</span>
  1618. <span class="bp">self</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">DEBUG</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
  1619. <span class="k">def</span> <span class="nf">info</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  1620. <span class="sd">&quot;&quot;&quot;</span>
  1621. <span class="sd"> Delegate an info call to the underlying logger.</span>
  1622. <span class="sd"> &quot;&quot;&quot;</span>
  1623. <span class="bp">self</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">INFO</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
  1624. <span class="k">def</span> <span class="nf">warning</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  1625. <span class="sd">&quot;&quot;&quot;</span>
  1626. <span class="sd"> Delegate a warning call to the underlying logger.</span>
  1627. <span class="sd"> &quot;&quot;&quot;</span>
  1628. <span class="bp">self</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">WARNING</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
  1629. <span class="k">def</span> <span class="nf">warn</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  1630. <span class="n">warnings</span><span class="o">.</span><span class="n">warn</span><span class="p">(</span><span class="s2">&quot;The &#39;warn&#39; method is deprecated, &quot;</span>
  1631. <span class="s2">&quot;use &#39;warning&#39; instead&quot;</span><span class="p">,</span> <span class="ne">DeprecationWarning</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
  1632. <span class="bp">self</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
  1633. <span class="k">def</span> <span class="nf">error</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  1634. <span class="sd">&quot;&quot;&quot;</span>
  1635. <span class="sd"> Delegate an error call to the underlying logger.</span>
  1636. <span class="sd"> &quot;&quot;&quot;</span>
  1637. <span class="bp">self</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">ERROR</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
  1638. <span class="k">def</span> <span class="nf">exception</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="n">exc_info</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  1639. <span class="sd">&quot;&quot;&quot;</span>
  1640. <span class="sd"> Delegate an exception call to the underlying logger.</span>
  1641. <span class="sd"> &quot;&quot;&quot;</span>
  1642. <span class="bp">self</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">ERROR</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="n">exc_info</span><span class="o">=</span><span class="n">exc_info</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
  1643. <span class="k">def</span> <span class="nf">critical</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  1644. <span class="sd">&quot;&quot;&quot;</span>
  1645. <span class="sd"> Delegate a critical call to the underlying logger.</span>
  1646. <span class="sd"> &quot;&quot;&quot;</span>
  1647. <span class="bp">self</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">CRITICAL</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
  1648. <span class="k">def</span> <span class="nf">log</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">level</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  1649. <span class="sd">&quot;&quot;&quot;</span>
  1650. <span class="sd"> Delegate a log call to the underlying logger, after adding</span>
  1651. <span class="sd"> contextual information from this adapter instance.</span>
  1652. <span class="sd"> &quot;&quot;&quot;</span>
  1653. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">isEnabledFor</span><span class="p">(</span><span class="n">level</span><span class="p">):</span>
  1654. <span class="n">msg</span><span class="p">,</span> <span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">process</span><span class="p">(</span><span class="n">msg</span><span class="p">,</span> <span class="n">kwargs</span><span class="p">)</span>
  1655. <span class="bp">self</span><span class="o">.</span><span class="n">logger</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
  1656. <span class="k">def</span> <span class="nf">isEnabledFor</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">level</span><span class="p">):</span>
  1657. <span class="sd">&quot;&quot;&quot;</span>
  1658. <span class="sd"> Is this logger enabled for level &#39;level&#39;?</span>
  1659. <span class="sd"> &quot;&quot;&quot;</span>
  1660. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">logger</span><span class="o">.</span><span class="n">isEnabledFor</span><span class="p">(</span><span class="n">level</span><span class="p">)</span>
  1661. <span class="k">def</span> <span class="nf">setLevel</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">level</span><span class="p">):</span>
  1662. <span class="sd">&quot;&quot;&quot;</span>
  1663. <span class="sd"> Set the specified level on the underlying logger.</span>
  1664. <span class="sd"> &quot;&quot;&quot;</span>
  1665. <span class="bp">self</span><span class="o">.</span><span class="n">logger</span><span class="o">.</span><span class="n">setLevel</span><span class="p">(</span><span class="n">level</span><span class="p">)</span>
  1666. <span class="k">def</span> <span class="nf">getEffectiveLevel</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  1667. <span class="sd">&quot;&quot;&quot;</span>
  1668. <span class="sd"> Get the effective level for the underlying logger.</span>
  1669. <span class="sd"> &quot;&quot;&quot;</span>
  1670. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">logger</span><span class="o">.</span><span class="n">getEffectiveLevel</span><span class="p">()</span>
  1671. <span class="k">def</span> <span class="nf">hasHandlers</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  1672. <span class="sd">&quot;&quot;&quot;</span>
  1673. <span class="sd"> See if the underlying logger has any handlers.</span>
  1674. <span class="sd"> &quot;&quot;&quot;</span>
  1675. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">logger</span><span class="o">.</span><span class="n">hasHandlers</span><span class="p">()</span>
  1676. <span class="k">def</span> <span class="nf">_log</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">level</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="n">args</span><span class="p">,</span> <span class="n">exc_info</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">extra</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">stack_info</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
  1677. <span class="sd">&quot;&quot;&quot;</span>
  1678. <span class="sd"> Low-level log implementation, proxied to allow nested logger adapters.</span>
  1679. <span class="sd"> &quot;&quot;&quot;</span>
  1680. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">logger</span><span class="o">.</span><span class="n">_log</span><span class="p">(</span>
  1681. <span class="n">level</span><span class="p">,</span>
  1682. <span class="n">msg</span><span class="p">,</span>
  1683. <span class="n">args</span><span class="p">,</span>
  1684. <span class="n">exc_info</span><span class="o">=</span><span class="n">exc_info</span><span class="p">,</span>
  1685. <span class="n">extra</span><span class="o">=</span><span class="n">extra</span><span class="p">,</span>
  1686. <span class="n">stack_info</span><span class="o">=</span><span class="n">stack_info</span><span class="p">,</span>
  1687. <span class="p">)</span>
  1688. <span class="nd">@property</span>
  1689. <span class="k">def</span> <span class="nf">manager</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  1690. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">logger</span><span class="o">.</span><span class="n">manager</span>
  1691. <span class="nd">@manager</span><span class="o">.</span><span class="n">setter</span>
  1692. <span class="k">def</span> <span class="nf">manager</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">):</span>
  1693. <span class="bp">self</span><span class="o">.</span><span class="n">logger</span><span class="o">.</span><span class="n">manager</span> <span class="o">=</span> <span class="n">value</span>
  1694. <span class="nd">@property</span>
  1695. <span class="k">def</span> <span class="nf">name</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  1696. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">logger</span><span class="o">.</span><span class="n">name</span>
  1697. <span class="k">def</span> <span class="fm">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  1698. <span class="n">logger</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">logger</span>
  1699. <span class="n">level</span> <span class="o">=</span> <span class="n">getLevelName</span><span class="p">(</span><span class="n">logger</span><span class="o">.</span><span class="n">getEffectiveLevel</span><span class="p">())</span>
  1700. <span class="k">return</span> <span class="s1">&#39;&lt;</span><span class="si">%s</span><span class="s1"> </span><span class="si">%s</span><span class="s1"> (</span><span class="si">%s</span><span class="s1">)&gt;&#39;</span> <span class="o">%</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">,</span> <span class="n">logger</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">level</span><span class="p">)</span>
  1701. <span class="n">root</span> <span class="o">=</span> <span class="n">RootLogger</span><span class="p">(</span><span class="n">WARNING</span><span class="p">)</span>
  1702. <span class="n">Logger</span><span class="o">.</span><span class="n">root</span> <span class="o">=</span> <span class="n">root</span>
  1703. <span class="n">Logger</span><span class="o">.</span><span class="n">manager</span> <span class="o">=</span> <span class="n">Manager</span><span class="p">(</span><span class="n">Logger</span><span class="o">.</span><span class="n">root</span><span class="p">)</span>
  1704. <span class="c1">#---------------------------------------------------------------------------</span>
  1705. <span class="c1"># Configuration classes and functions</span>
  1706. <span class="c1">#---------------------------------------------------------------------------</span>
  1707. <span class="k">def</span> <span class="nf">basicConfig</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  1708. <span class="sd">&quot;&quot;&quot;</span>
  1709. <span class="sd"> Do basic configuration for the logging system.</span>
  1710. <span class="sd"> This function does nothing if the root logger already has handlers</span>
  1711. <span class="sd"> configured, unless the keyword argument *force* is set to ``True``.</span>
  1712. <span class="sd"> It is a convenience method intended for use by simple scripts</span>
  1713. <span class="sd"> to do one-shot configuration of the logging package.</span>
  1714. <span class="sd"> The default behaviour is to create a StreamHandler which writes to</span>
  1715. <span class="sd"> sys.stderr, set a formatter using the BASIC_FORMAT format string, and</span>
  1716. <span class="sd"> add the handler to the root logger.</span>
  1717. <span class="sd"> A number of optional keyword arguments may be specified, which can alter</span>
  1718. <span class="sd"> the default behaviour.</span>
  1719. <span class="sd"> filename Specifies that a FileHandler be created, using the specified</span>
  1720. <span class="sd"> filename, rather than a StreamHandler.</span>
  1721. <span class="sd"> filemode Specifies the mode to open the file, if filename is specified</span>
  1722. <span class="sd"> (if filemode is unspecified, it defaults to &#39;a&#39;).</span>
  1723. <span class="sd"> format Use the specified format string for the handler.</span>
  1724. <span class="sd"> datefmt Use the specified date/time format.</span>
  1725. <span class="sd"> style If a format string is specified, use this to specify the</span>
  1726. <span class="sd"> type of format string (possible values &#39;%&#39;, &#39;{&#39;, &#39;$&#39;, for</span>
  1727. <span class="sd"> %-formatting, :meth:`str.format` and :class:`string.Template`</span>
  1728. <span class="sd"> - defaults to &#39;%&#39;).</span>
  1729. <span class="sd"> level Set the root logger level to the specified level.</span>
  1730. <span class="sd"> stream Use the specified stream to initialize the StreamHandler. Note</span>
  1731. <span class="sd"> that this argument is incompatible with &#39;filename&#39; - if both</span>
  1732. <span class="sd"> are present, &#39;stream&#39; is ignored.</span>
  1733. <span class="sd"> handlers If specified, this should be an iterable of already created</span>
  1734. <span class="sd"> handlers, which will be added to the root handler. Any handler</span>
  1735. <span class="sd"> in the list which does not have a formatter assigned will be</span>
  1736. <span class="sd"> assigned the formatter created in this function.</span>
  1737. <span class="sd"> force If this keyword is specified as true, any existing handlers</span>
  1738. <span class="sd"> attached to the root logger are removed and closed, before</span>
  1739. <span class="sd"> carrying out the configuration as specified by the other</span>
  1740. <span class="sd"> arguments.</span>
  1741. <span class="sd"> encoding If specified together with a filename, this encoding is passed to</span>
  1742. <span class="sd"> the created FileHandler, causing it to be used when the file is</span>
  1743. <span class="sd"> opened.</span>
  1744. <span class="sd"> errors If specified together with a filename, this value is passed to the</span>
  1745. <span class="sd"> created FileHandler, causing it to be used when the file is</span>
  1746. <span class="sd"> opened in text mode. If not specified, the default value is</span>
  1747. <span class="sd"> `backslashreplace`.</span>
  1748. <span class="sd"> Note that you could specify a stream created using open(filename, mode)</span>
  1749. <span class="sd"> rather than passing the filename and mode in. However, it should be</span>
  1750. <span class="sd"> remembered that StreamHandler does not close its stream (since it may be</span>
  1751. <span class="sd"> using sys.stdout or sys.stderr), whereas FileHandler closes its stream</span>
  1752. <span class="sd"> when the handler is closed.</span>
  1753. <span class="sd"> .. versionchanged:: 3.2</span>
  1754. <span class="sd"> Added the ``style`` parameter.</span>
  1755. <span class="sd"> .. versionchanged:: 3.3</span>
  1756. <span class="sd"> Added the ``handlers`` parameter. A ``ValueError`` is now thrown for</span>
  1757. <span class="sd"> incompatible arguments (e.g. ``handlers`` specified together with</span>
  1758. <span class="sd"> ``filename``/``filemode``, or ``filename``/``filemode`` specified</span>
  1759. <span class="sd"> together with ``stream``, or ``handlers`` specified together with</span>
  1760. <span class="sd"> ``stream``.</span>
  1761. <span class="sd"> .. versionchanged:: 3.8</span>
  1762. <span class="sd"> Added the ``force`` parameter.</span>
  1763. <span class="sd"> .. versionchanged:: 3.9</span>
  1764. <span class="sd"> Added the ``encoding`` and ``errors`` parameters.</span>
  1765. <span class="sd"> &quot;&quot;&quot;</span>
  1766. <span class="c1"># Add thread safety in case someone mistakenly calls</span>
  1767. <span class="c1"># basicConfig() from multiple threads</span>
  1768. <span class="n">_acquireLock</span><span class="p">()</span>
  1769. <span class="k">try</span><span class="p">:</span>
  1770. <span class="n">force</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;force&#39;</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span>
  1771. <span class="n">encoding</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;encoding&#39;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
  1772. <span class="n">errors</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;errors&#39;</span><span class="p">,</span> <span class="s1">&#39;backslashreplace&#39;</span><span class="p">)</span>
  1773. <span class="k">if</span> <span class="n">force</span><span class="p">:</span>
  1774. <span class="k">for</span> <span class="n">h</span> <span class="ow">in</span> <span class="n">root</span><span class="o">.</span><span class="n">handlers</span><span class="p">[:]:</span>
  1775. <span class="n">root</span><span class="o">.</span><span class="n">removeHandler</span><span class="p">(</span><span class="n">h</span><span class="p">)</span>
  1776. <span class="n">h</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>
  1777. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">root</span><span class="o">.</span><span class="n">handlers</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
  1778. <span class="n">handlers</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;handlers&quot;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
  1779. <span class="k">if</span> <span class="n">handlers</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
  1780. <span class="k">if</span> <span class="s2">&quot;stream&quot;</span> <span class="ow">in</span> <span class="n">kwargs</span> <span class="ow">and</span> <span class="s2">&quot;filename&quot;</span> <span class="ow">in</span> <span class="n">kwargs</span><span class="p">:</span>
  1781. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;&#39;stream&#39; and &#39;filename&#39; should not be &quot;</span>
  1782. <span class="s2">&quot;specified together&quot;</span><span class="p">)</span>
  1783. <span class="k">else</span><span class="p">:</span>
  1784. <span class="k">if</span> <span class="s2">&quot;stream&quot;</span> <span class="ow">in</span> <span class="n">kwargs</span> <span class="ow">or</span> <span class="s2">&quot;filename&quot;</span> <span class="ow">in</span> <span class="n">kwargs</span><span class="p">:</span>
  1785. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;&#39;stream&#39; or &#39;filename&#39; should not be &quot;</span>
  1786. <span class="s2">&quot;specified together with &#39;handlers&#39;&quot;</span><span class="p">)</span>
  1787. <span class="k">if</span> <span class="n">handlers</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
  1788. <span class="n">filename</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;filename&quot;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
  1789. <span class="n">mode</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;filemode&quot;</span><span class="p">,</span> <span class="s1">&#39;a&#39;</span><span class="p">)</span>
  1790. <span class="k">if</span> <span class="n">filename</span><span class="p">:</span>
  1791. <span class="k">if</span> <span class="s1">&#39;b&#39;</span><span class="ow">in</span> <span class="n">mode</span><span class="p">:</span>
  1792. <span class="n">errors</span> <span class="o">=</span> <span class="kc">None</span>
  1793. <span class="n">h</span> <span class="o">=</span> <span class="n">FileHandler</span><span class="p">(</span><span class="n">filename</span><span class="p">,</span> <span class="n">mode</span><span class="p">,</span>
  1794. <span class="n">encoding</span><span class="o">=</span><span class="n">encoding</span><span class="p">,</span> <span class="n">errors</span><span class="o">=</span><span class="n">errors</span><span class="p">)</span>
  1795. <span class="k">else</span><span class="p">:</span>
  1796. <span class="n">stream</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;stream&quot;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
  1797. <span class="n">h</span> <span class="o">=</span> <span class="n">StreamHandler</span><span class="p">(</span><span class="n">stream</span><span class="p">)</span>
  1798. <span class="n">handlers</span> <span class="o">=</span> <span class="p">[</span><span class="n">h</span><span class="p">]</span>
  1799. <span class="n">dfs</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;datefmt&quot;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
  1800. <span class="n">style</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;style&quot;</span><span class="p">,</span> <span class="s1">&#39;%&#39;</span><span class="p">)</span>
  1801. <span class="k">if</span> <span class="n">style</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">_STYLES</span><span class="p">:</span>
  1802. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">&#39;Style must be one of: </span><span class="si">%s</span><span class="s1">&#39;</span> <span class="o">%</span> <span class="s1">&#39;,&#39;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span>
  1803. <span class="n">_STYLES</span><span class="o">.</span><span class="n">keys</span><span class="p">()))</span>
  1804. <span class="n">fs</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;format&quot;</span><span class="p">,</span> <span class="n">_STYLES</span><span class="p">[</span><span class="n">style</span><span class="p">][</span><span class="mi">1</span><span class="p">])</span>
  1805. <span class="n">fmt</span> <span class="o">=</span> <span class="n">Formatter</span><span class="p">(</span><span class="n">fs</span><span class="p">,</span> <span class="n">dfs</span><span class="p">,</span> <span class="n">style</span><span class="p">)</span>
  1806. <span class="k">for</span> <span class="n">h</span> <span class="ow">in</span> <span class="n">handlers</span><span class="p">:</span>
  1807. <span class="k">if</span> <span class="n">h</span><span class="o">.</span><span class="n">formatter</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
  1808. <span class="n">h</span><span class="o">.</span><span class="n">setFormatter</span><span class="p">(</span><span class="n">fmt</span><span class="p">)</span>
  1809. <span class="n">root</span><span class="o">.</span><span class="n">addHandler</span><span class="p">(</span><span class="n">h</span><span class="p">)</span>
  1810. <span class="n">level</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;level&quot;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
  1811. <span class="k">if</span> <span class="n">level</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  1812. <span class="n">root</span><span class="o">.</span><span class="n">setLevel</span><span class="p">(</span><span class="n">level</span><span class="p">)</span>
  1813. <span class="k">if</span> <span class="n">kwargs</span><span class="p">:</span>
  1814. <span class="n">keys</span> <span class="o">=</span> <span class="s1">&#39;, &#39;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">kwargs</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span>
  1815. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">&#39;Unrecognised argument(s): </span><span class="si">%s</span><span class="s1">&#39;</span> <span class="o">%</span> <span class="n">keys</span><span class="p">)</span>
  1816. <span class="k">finally</span><span class="p">:</span>
  1817. <span class="n">_releaseLock</span><span class="p">()</span>
  1818. <span class="c1">#---------------------------------------------------------------------------</span>
  1819. <span class="c1"># Utility functions at module level.</span>
  1820. <span class="c1"># Basically delegate everything to the root logger.</span>
  1821. <span class="c1">#---------------------------------------------------------------------------</span>
  1822. <span class="k">def</span> <span class="nf">getLogger</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  1823. <span class="sd">&quot;&quot;&quot;</span>
  1824. <span class="sd"> Return a logger with the specified name, creating it if necessary.</span>
  1825. <span class="sd"> If no name is specified, return the root logger.</span>
  1826. <span class="sd"> &quot;&quot;&quot;</span>
  1827. <span class="k">if</span> <span class="ow">not</span> <span class="n">name</span> <span class="ow">or</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="nb">str</span><span class="p">)</span> <span class="ow">and</span> <span class="n">name</span> <span class="o">==</span> <span class="n">root</span><span class="o">.</span><span class="n">name</span><span class="p">:</span>
  1828. <span class="k">return</span> <span class="n">root</span>
  1829. <span class="k">return</span> <span class="n">Logger</span><span class="o">.</span><span class="n">manager</span><span class="o">.</span><span class="n">getLogger</span><span class="p">(</span><span class="n">name</span><span class="p">)</span>
  1830. <span class="k">def</span> <span class="nf">critical</span><span class="p">(</span><span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  1831. <span class="sd">&quot;&quot;&quot;</span>
  1832. <span class="sd"> Log a message with severity &#39;CRITICAL&#39; on the root logger. If the logger</span>
  1833. <span class="sd"> has no handlers, call basicConfig() to add a console handler with a</span>
  1834. <span class="sd"> pre-defined format.</span>
  1835. <span class="sd"> &quot;&quot;&quot;</span>
  1836. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">root</span><span class="o">.</span><span class="n">handlers</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
  1837. <span class="n">basicConfig</span><span class="p">()</span>
  1838. <span class="n">root</span><span class="o">.</span><span class="n">critical</span><span class="p">(</span><span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
  1839. <span class="n">fatal</span> <span class="o">=</span> <span class="n">critical</span>
  1840. <span class="k">def</span> <span class="nf">error</span><span class="p">(</span><span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  1841. <span class="sd">&quot;&quot;&quot;</span>
  1842. <span class="sd"> Log a message with severity &#39;ERROR&#39; on the root logger. If the logger has</span>
  1843. <span class="sd"> no handlers, call basicConfig() to add a console handler with a pre-defined</span>
  1844. <span class="sd"> format.</span>
  1845. <span class="sd"> &quot;&quot;&quot;</span>
  1846. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">root</span><span class="o">.</span><span class="n">handlers</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
  1847. <span class="n">basicConfig</span><span class="p">()</span>
  1848. <span class="n">root</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
  1849. <span class="k">def</span> <span class="nf">exception</span><span class="p">(</span><span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="n">exc_info</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  1850. <span class="sd">&quot;&quot;&quot;</span>
  1851. <span class="sd"> Log a message with severity &#39;ERROR&#39; on the root logger, with exception</span>
  1852. <span class="sd"> information. If the logger has no handlers, basicConfig() is called to add</span>
  1853. <span class="sd"> a console handler with a pre-defined format.</span>
  1854. <span class="sd"> &quot;&quot;&quot;</span>
  1855. <span class="n">error</span><span class="p">(</span><span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="n">exc_info</span><span class="o">=</span><span class="n">exc_info</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
  1856. <span class="k">def</span> <span class="nf">warning</span><span class="p">(</span><span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  1857. <span class="sd">&quot;&quot;&quot;</span>
  1858. <span class="sd"> Log a message with severity &#39;WARNING&#39; on the root logger. If the logger has</span>
  1859. <span class="sd"> no handlers, call basicConfig() to add a console handler with a pre-defined</span>
  1860. <span class="sd"> format.</span>
  1861. <span class="sd"> &quot;&quot;&quot;</span>
  1862. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">root</span><span class="o">.</span><span class="n">handlers</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
  1863. <span class="n">basicConfig</span><span class="p">()</span>
  1864. <span class="n">root</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
  1865. <span class="k">def</span> <span class="nf">warn</span><span class="p">(</span><span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  1866. <span class="n">warnings</span><span class="o">.</span><span class="n">warn</span><span class="p">(</span><span class="s2">&quot;The &#39;warn&#39; function is deprecated, &quot;</span>
  1867. <span class="s2">&quot;use &#39;warning&#39; instead&quot;</span><span class="p">,</span> <span class="ne">DeprecationWarning</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
  1868. <span class="n">warning</span><span class="p">(</span><span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
  1869. <span class="k">def</span> <span class="nf">info</span><span class="p">(</span><span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  1870. <span class="sd">&quot;&quot;&quot;</span>
  1871. <span class="sd"> Log a message with severity &#39;INFO&#39; on the root logger. If the logger has</span>
  1872. <span class="sd"> no handlers, call basicConfig() to add a console handler with a pre-defined</span>
  1873. <span class="sd"> format.</span>
  1874. <span class="sd"> &quot;&quot;&quot;</span>
  1875. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">root</span><span class="o">.</span><span class="n">handlers</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
  1876. <span class="n">basicConfig</span><span class="p">()</span>
  1877. <span class="n">root</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
  1878. <span class="k">def</span> <span class="nf">debug</span><span class="p">(</span><span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  1879. <span class="sd">&quot;&quot;&quot;</span>
  1880. <span class="sd"> Log a message with severity &#39;DEBUG&#39; on the root logger. If the logger has</span>
  1881. <span class="sd"> no handlers, call basicConfig() to add a console handler with a pre-defined</span>
  1882. <span class="sd"> format.</span>
  1883. <span class="sd"> &quot;&quot;&quot;</span>
  1884. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">root</span><span class="o">.</span><span class="n">handlers</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
  1885. <span class="n">basicConfig</span><span class="p">()</span>
  1886. <span class="n">root</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
  1887. <span class="k">def</span> <span class="nf">log</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  1888. <span class="sd">&quot;&quot;&quot;</span>
  1889. <span class="sd"> Log &#39;msg % args&#39; with the integer severity &#39;level&#39; on the root logger. If</span>
  1890. <span class="sd"> the logger has no handlers, call basicConfig() to add a console handler</span>
  1891. <span class="sd"> with a pre-defined format.</span>
  1892. <span class="sd"> &quot;&quot;&quot;</span>
  1893. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">root</span><span class="o">.</span><span class="n">handlers</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
  1894. <span class="n">basicConfig</span><span class="p">()</span>
  1895. <span class="n">root</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="n">msg</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
  1896. <span class="k">def</span> <span class="nf">disable</span><span class="p">(</span><span class="n">level</span><span class="o">=</span><span class="n">CRITICAL</span><span class="p">):</span>
  1897. <span class="sd">&quot;&quot;&quot;</span>
  1898. <span class="sd"> Disable all logging calls of severity &#39;level&#39; and below.</span>
  1899. <span class="sd"> &quot;&quot;&quot;</span>
  1900. <span class="n">root</span><span class="o">.</span><span class="n">manager</span><span class="o">.</span><span class="n">disable</span> <span class="o">=</span> <span class="n">level</span>
  1901. <span class="n">root</span><span class="o">.</span><span class="n">manager</span><span class="o">.</span><span class="n">_clear_cache</span><span class="p">()</span>
  1902. <span class="k">def</span> <span class="nf">shutdown</span><span class="p">(</span><span class="n">handlerList</span><span class="o">=</span><span class="n">_handlerList</span><span class="p">):</span>
  1903. <span class="sd">&quot;&quot;&quot;</span>
  1904. <span class="sd"> Perform any cleanup actions in the logging system (e.g. flushing</span>
  1905. <span class="sd"> buffers).</span>
  1906. <span class="sd"> Should be called at application exit.</span>
  1907. <span class="sd"> &quot;&quot;&quot;</span>
  1908. <span class="k">for</span> <span class="n">wr</span> <span class="ow">in</span> <span class="nb">reversed</span><span class="p">(</span><span class="n">handlerList</span><span class="p">[:]):</span>
  1909. <span class="c1">#errors might occur, for example, if files are locked</span>
  1910. <span class="c1">#we just ignore them if raiseExceptions is not set</span>
  1911. <span class="k">try</span><span class="p">:</span>
  1912. <span class="n">h</span> <span class="o">=</span> <span class="n">wr</span><span class="p">()</span>
  1913. <span class="k">if</span> <span class="n">h</span><span class="p">:</span>
  1914. <span class="k">try</span><span class="p">:</span>
  1915. <span class="n">h</span><span class="o">.</span><span class="n">acquire</span><span class="p">()</span>
  1916. <span class="n">h</span><span class="o">.</span><span class="n">flush</span><span class="p">()</span>
  1917. <span class="n">h</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>
  1918. <span class="k">except</span> <span class="p">(</span><span class="ne">OSError</span><span class="p">,</span> <span class="ne">ValueError</span><span class="p">):</span>
  1919. <span class="c1"># Ignore errors which might be caused</span>
  1920. <span class="c1"># because handlers have been closed but</span>
  1921. <span class="c1"># references to them are still around at</span>
  1922. <span class="c1"># application exit.</span>
  1923. <span class="k">pass</span>
  1924. <span class="k">finally</span><span class="p">:</span>
  1925. <span class="n">h</span><span class="o">.</span><span class="n">release</span><span class="p">()</span>
  1926. <span class="k">except</span><span class="p">:</span> <span class="c1"># ignore everything, as we&#39;re shutting down</span>
  1927. <span class="k">if</span> <span class="n">raiseExceptions</span><span class="p">:</span>
  1928. <span class="k">raise</span>
  1929. <span class="c1">#else, swallow</span>
  1930. <span class="c1">#Let&#39;s try and shutdown automatically on application exit...</span>
  1931. <span class="kn">import</span> <span class="nn">atexit</span>
  1932. <span class="n">atexit</span><span class="o">.</span><span class="n">register</span><span class="p">(</span><span class="n">shutdown</span><span class="p">)</span>
  1933. <span class="c1"># Null handler</span>
  1934. <span class="k">class</span> <span class="nc">NullHandler</span><span class="p">(</span><span class="n">Handler</span><span class="p">):</span>
  1935. <span class="sd">&quot;&quot;&quot;</span>
  1936. <span class="sd"> This handler does nothing. It&#39;s intended to be used to avoid the</span>
  1937. <span class="sd"> &quot;No handlers could be found for logger XXX&quot; one-off warning. This is</span>
  1938. <span class="sd"> important for library code, which may contain code to log events. If a user</span>
  1939. <span class="sd"> of the library does not configure logging, the one-off warning might be</span>
  1940. <span class="sd"> produced; to avoid this, the library developer simply needs to instantiate</span>
  1941. <span class="sd"> a NullHandler and add it to the top-level logger of the library module or</span>
  1942. <span class="sd"> package.</span>
  1943. <span class="sd"> &quot;&quot;&quot;</span>
  1944. <span class="k">def</span> <span class="nf">handle</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">record</span><span class="p">):</span>
  1945. <span class="sd">&quot;&quot;&quot;Stub.&quot;&quot;&quot;</span>
  1946. <span class="k">def</span> <span class="nf">emit</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">record</span><span class="p">):</span>
  1947. <span class="sd">&quot;&quot;&quot;Stub.&quot;&quot;&quot;</span>
  1948. <span class="k">def</span> <span class="nf">createLock</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  1949. <span class="bp">self</span><span class="o">.</span><span class="n">lock</span> <span class="o">=</span> <span class="kc">None</span>
  1950. <span class="k">def</span> <span class="nf">_at_fork_reinit</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  1951. <span class="k">pass</span>
  1952. <span class="c1"># Warnings integration</span>
  1953. <span class="n">_warnings_showwarning</span> <span class="o">=</span> <span class="kc">None</span>
  1954. <span class="k">def</span> <span class="nf">_showwarning</span><span class="p">(</span><span class="n">message</span><span class="p">,</span> <span class="n">category</span><span class="p">,</span> <span class="n">filename</span><span class="p">,</span> <span class="n">lineno</span><span class="p">,</span> <span class="n">file</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">line</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  1955. <span class="sd">&quot;&quot;&quot;</span>
  1956. <span class="sd"> Implementation of showwarnings which redirects to logging, which will first</span>
  1957. <span class="sd"> check to see if the file parameter is None. If a file is specified, it will</span>
  1958. <span class="sd"> delegate to the original warnings implementation of showwarning. Otherwise,</span>
  1959. <span class="sd"> it will call warnings.formatwarning and will log the resulting string to a</span>
  1960. <span class="sd"> warnings logger named &quot;py.warnings&quot; with level logging.WARNING.</span>
  1961. <span class="sd"> &quot;&quot;&quot;</span>
  1962. <span class="k">if</span> <span class="n">file</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  1963. <span class="k">if</span> <span class="n">_warnings_showwarning</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  1964. <span class="n">_warnings_showwarning</span><span class="p">(</span><span class="n">message</span><span class="p">,</span> <span class="n">category</span><span class="p">,</span> <span class="n">filename</span><span class="p">,</span> <span class="n">lineno</span><span class="p">,</span> <span class="n">file</span><span class="p">,</span> <span class="n">line</span><span class="p">)</span>
  1965. <span class="k">else</span><span class="p">:</span>
  1966. <span class="n">s</span> <span class="o">=</span> <span class="n">warnings</span><span class="o">.</span><span class="n">formatwarning</span><span class="p">(</span><span class="n">message</span><span class="p">,</span> <span class="n">category</span><span class="p">,</span> <span class="n">filename</span><span class="p">,</span> <span class="n">lineno</span><span class="p">,</span> <span class="n">line</span><span class="p">)</span>
  1967. <span class="n">logger</span> <span class="o">=</span> <span class="n">getLogger</span><span class="p">(</span><span class="s2">&quot;py.warnings&quot;</span><span class="p">)</span>
  1968. <span class="k">if</span> <span class="ow">not</span> <span class="n">logger</span><span class="o">.</span><span class="n">handlers</span><span class="p">:</span>
  1969. <span class="n">logger</span><span class="o">.</span><span class="n">addHandler</span><span class="p">(</span><span class="n">NullHandler</span><span class="p">())</span>
  1970. <span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s2">&quot;</span><span class="si">%s</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">s</span><span class="p">)</span>
  1971. <span class="k">def</span> <span class="nf">captureWarnings</span><span class="p">(</span><span class="n">capture</span><span class="p">):</span>
  1972. <span class="sd">&quot;&quot;&quot;</span>
  1973. <span class="sd"> If capture is true, redirect all warnings to the logging package.</span>
  1974. <span class="sd"> If capture is False, ensure that warnings are not redirected to logging</span>
  1975. <span class="sd"> but to their original destinations.</span>
  1976. <span class="sd"> &quot;&quot;&quot;</span>
  1977. <span class="k">global</span> <span class="n">_warnings_showwarning</span>
  1978. <span class="k">if</span> <span class="n">capture</span><span class="p">:</span>
  1979. <span class="k">if</span> <span class="n">_warnings_showwarning</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
  1980. <span class="n">_warnings_showwarning</span> <span class="o">=</span> <span class="n">warnings</span><span class="o">.</span><span class="n">showwarning</span>
  1981. <span class="n">warnings</span><span class="o">.</span><span class="n">showwarning</span> <span class="o">=</span> <span class="n">_showwarning</span>
  1982. <span class="k">else</span><span class="p">:</span>
  1983. <span class="k">if</span> <span class="n">_warnings_showwarning</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  1984. <span class="n">warnings</span><span class="o">.</span><span class="n">showwarning</span> <span class="o">=</span> <span class="n">_warnings_showwarning</span>
  1985. <span class="n">_warnings_showwarning</span> <span class="o">=</span> <span class="kc">None</span>
  1986. </pre></div>
  1987. </div>
  1988. </div>
  1989. <footer>
  1990. <hr/>
  1991. <div role="contentinfo">
  1992. <p>&#169; Copyright 2021, SuperGradients team.</p>
  1993. </div>
  1994. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  1995. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  1996. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  1997. </footer>
  1998. </div>
  1999. </div>
  2000. </section>
  2001. </div>
  2002. <script>
  2003. jQuery(function () {
  2004. SphinxRtdTheme.Navigation.enable(true);
  2005. });
  2006. </script>
  2007. </body>
  2008. </html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.common.abstractions.abstract_logger &mdash; SuperGradients 1.0 documentation</title>
  7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
  10. <!--[if lt IE 9]>
  11. <script src="../../../../_static/js/html5shiv.min.js"></script>
  12. <![endif]-->
  13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
  14. <script src="../../../../_static/jquery.js"></script>
  15. <script src="../../../../_static/underscore.js"></script>
  16. <script src="../../../../_static/doctools.js"></script>
  17. <script src="../../../../_static/js/theme.js"></script>
  18. <link rel="index" title="Index" href="../../../../genindex.html" />
  19. <link rel="search" title="Search" href="../../../../search.html" />
  20. </head>
  21. <body class="wy-body-for-nav">
  22. <div class="wy-grid-for-nav">
  23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  24. <div class="wy-side-scroll">
  25. <div class="wy-side-nav-search" >
  26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
  27. </a>
  28. <div role="search">
  29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
  30. <input type="text" name="q" placeholder="Search docs" />
  31. <input type="hidden" name="check_keywords" value="yes" />
  32. <input type="hidden" name="area" value="default" />
  33. </form>
  34. </div>
  35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
  37. <ul>
  38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
  39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
  40. </ul>
  41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
  42. <ul>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
  45. </ul>
  46. <p class="caption"><span class="caption-text">User Guide</span></p>
  47. <ul>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
  57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
  58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
  59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
  60. </ul>
  61. </div>
  62. </div>
  63. </nav>
  64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  66. <a href="../../../../index.html">SuperGradients</a>
  67. </nav>
  68. <div class="wy-nav-content">
  69. <div class="rst-content">
  70. <div role="navigation" aria-label="Page navigation">
  71. <ul class="wy-breadcrumbs">
  72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
  74. <li>super_gradients.common.abstractions.abstract_logger</li>
  75. <li class="wy-breadcrumbs-aside">
  76. </li>
  77. </ul>
  78. <hr/>
  79. </div>
  80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  81. <div itemprop="articleBody">
  82. <h1>Source code for super_gradients.common.abstractions.abstract_logger</h1><div class="highlight"><pre>
  83. <span></span><span class="kn">import</span> <span class="nn">logging</span>
  84. <span class="kn">import</span> <span class="nn">logging.config</span>
  85. <span class="kn">from</span> <span class="nn">super_gradients.common.auto_logging</span> <span class="kn">import</span> <span class="n">AutoLoggerConfig</span>
  86. <span class="kn">from</span> <span class="nn">super_gradients.common.environment.environment_config</span> <span class="kn">import</span> <span class="n">DEFAULT_LOGGING_LEVEL</span>
  87. <div class="viewcode-block" id="get_logger"><a class="viewcode-back" href="../../../../super_gradients.common.abstractions.html#super_gradients.common.abstractions.abstract_logger.get_logger">[docs]</a><span class="k">def</span> <span class="nf">get_logger</span><span class="p">(</span>
  88. <span class="n">logger_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">training_log_path</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">logs_dir_path</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">log_level</span><span class="o">=</span><span class="n">DEFAULT_LOGGING_LEVEL</span>
  89. <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">logging</span><span class="o">.</span><span class="n">Logger</span><span class="p">:</span>
  90. <span class="n">config_dict</span> <span class="o">=</span> <span class="n">AutoLoggerConfig</span><span class="o">.</span><span class="n">generate_config_for_module_name</span><span class="p">(</span>
  91. <span class="n">module_name</span><span class="o">=</span><span class="n">logger_name</span><span class="p">,</span> <span class="n">training_log_path</span><span class="o">=</span><span class="n">training_log_path</span><span class="p">,</span> <span class="n">logs_dir_path</span><span class="o">=</span><span class="n">logs_dir_path</span><span class="p">,</span> <span class="n">log_level</span><span class="o">=</span><span class="n">log_level</span>
  92. <span class="p">)</span>
  93. <span class="n">logging</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">dictConfig</span><span class="p">(</span><span class="n">config_dict</span><span class="p">)</span>
  94. <span class="n">logger</span><span class="p">:</span> <span class="n">logging</span><span class="o">.</span><span class="n">Logger</span> <span class="o">=</span> <span class="n">logging</span><span class="o">.</span><span class="n">getLogger</span><span class="p">(</span><span class="n">logger_name</span><span class="p">)</span>
  95. <span class="k">return</span> <span class="n">logger</span></div>
  96. <div class="viewcode-block" id="ILogger"><a class="viewcode-back" href="../../../../super_gradients.common.abstractions.html#super_gradients.common.abstractions.abstract_logger.ILogger">[docs]</a><span class="k">class</span> <span class="nc">ILogger</span><span class="p">:</span>
  97. <span class="sd">&quot;&quot;&quot;</span>
  98. <span class="sd"> Provides logging capabilities to the derived class.</span>
  99. <span class="sd"> &quot;&quot;&quot;</span>
  100. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logger_name</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  101. <span class="n">logger_name</span> <span class="o">=</span> <span class="n">logger_name</span> <span class="k">if</span> <span class="n">logger_name</span> <span class="k">else</span> <span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="vm">__module__</span><span class="p">)</span>
  102. <span class="bp">self</span><span class="o">.</span><span class="n">_logger</span><span class="p">:</span> <span class="n">logging</span><span class="o">.</span><span class="n">Logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="n">logger_name</span><span class="p">)</span></div>
  103. </pre></div>
  104. </div>
  105. </div>
  106. <footer>
  107. <hr/>
  108. <div role="contentinfo">
  109. <p>&#169; Copyright 2021, SuperGradients team.</p>
  110. </div>
  111. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  112. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  113. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  114. </footer>
  115. </div>
  116. </div>
  117. </section>
  118. </div>
  119. <script>
  120. jQuery(function () {
  121. SphinxRtdTheme.Navigation.enable(true);
  122. });
  123. </script>
  124. </body>
  125. </html>
Discard
Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
@@ -3,10 +3,11 @@
 <head>
 <head>
   <meta charset="utf-8" />
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-  <title>super_gradients.common.auto_logging.auto_logger &mdash; SuperGradients 1.0 documentation</title>
+  <title>super_gradients.common.auto_logging.auto_logger &mdash; SuperGradients 3.0.3 documentation</title>
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
+      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
   <!--[if lt IE 9]>
   <!--[if lt IE 9]>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
   <![endif]-->
   <![endif]-->
@@ -14,7 +15,9 @@
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/underscore.js"></script>
         <script src="../../../../_static/underscore.js"></script>
+        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
         <script src="../../../../_static/doctools.js"></script>
         <script src="../../../../_static/doctools.js"></script>
+        <script src="../../../../_static/sphinx_highlight.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="search" title="Search" href="../../../../search.html" /> 
     <link rel="search" title="Search" href="../../../../search.html" /> 
@@ -35,30 +38,29 @@
   </form>
   </form>
 </div>
 </div>
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
-              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
+              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
 <ul>
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
 </ul>
 </ul>
-<p class="caption"><span class="caption-text">Technical Documentation</span></p>
+<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
 <ul>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
-</ul>
-<p class="caption"><span class="caption-text">User Guide</span></p>
-<ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
 </ul>
 </ul>
 
 
         </div>
         </div>
@@ -86,116 +88,130 @@
            <div itemprop="articleBody">
            <div itemprop="articleBody">
              
              
   <h1>Source code for super_gradients.common.auto_logging.auto_logger</h1><div class="highlight"><pre>
   <h1>Source code for super_gradients.common.auto_logging.auto_logger</h1><div class="highlight"><pre>
-<span></span><span class="kn">import</span> <span class="nn">json</span>
+<span></span><span class="kn">import</span> <span class="nn">logging</span>
 <span class="kn">import</span> <span class="nn">os</span>
 <span class="kn">import</span> <span class="nn">os</span>
+<span class="kn">import</span> <span class="nn">sys</span>
+<span class="kn">import</span> <span class="nn">time</span>
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Union</span>
 
 
-<span class="kn">import</span> <span class="nn">pkg_resources</span>
-
-<span class="kn">from</span> <span class="nn">super_gradients.common.environment.environment_config</span> <span class="kn">import</span> <span class="n">DEFAULT_LOGGING_LEVEL</span>
 
 
-
-<div class="viewcode-block" id="AutoLoggerConfig"><a class="viewcode-back" href="../../../../super_gradients.common.auto_logging.html#super_gradients.common.auto_logging.auto_logger.AutoLoggerConfig">[docs]</a><span class="k">class</span> <span class="nc">AutoLoggerConfig</span><span class="p">:</span>
+<div class="viewcode-block" id="AutoLoggerConfig"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.AutoLoggerConfig">[docs]</a><span class="k">class</span> <span class="nc">AutoLoggerConfig</span><span class="p">:</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
-<span class="sd">    A Class for the Automated Logging Config that is created from the JSON config file (auto_logging_conf)</span>
+<span class="sd">    A Class for the Automated Logging Config</span>
 <span class="sd">    &quot;&quot;&quot;</span>
 <span class="sd">    &quot;&quot;&quot;</span>
 
 
-<div class="viewcode-block" id="AutoLoggerConfig.generate_config_for_module_name"><a class="viewcode-back" href="../../../../super_gradients.common.auto_logging.html#super_gradients.common.auto_logging.auto_logger.AutoLoggerConfig.generate_config_for_module_name">[docs]</a>    <span class="nd">@staticmethod</span>
-    <span class="k">def</span> <span class="nf">generate_config_for_module_name</span><span class="p">(</span>
-        <span class="n">module_name</span><span class="p">,</span>
-        <span class="n">training_log_path</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
-        <span class="n">log_level</span><span class="o">=</span><span class="n">DEFAULT_LOGGING_LEVEL</span><span class="p">,</span>
-        <span class="n">max_bytes</span><span class="o">=</span><span class="mi">10485760</span><span class="p">,</span>
-        <span class="n">logs_dir_path</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
-        <span class="n">handlers_list</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
-    <span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">dict</span><span class="p">:</span>
+    <span class="n">FILE_LOGGING_LEVEL</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;LOG_LEVEL&quot;</span><span class="p">,</span> <span class="s2">&quot;DEBUG&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">upper</span><span class="p">()</span>
+    <span class="n">CONSOLE_LOGGING_LEVEL</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;CONSOLE_LOG_LEVEL&quot;</span><span class="p">,</span> <span class="s2">&quot;INFO&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">upper</span><span class="p">()</span>
+
+    <span class="n">filename</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="kc">None</span><span class="p">]</span>
+
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">filename</span> <span class="o">=</span> <span class="kc">None</span>
+
+    <span class="k">def</span> <span class="nf">_setup_default_logging</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">log_level</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
-<span class="sd">        generate_config_for_module_name - Returns a Config Dict For Logging</span>
-<span class="sd">            :param module_name:     The Python Module name to create auto_logging for</span>
-<span class="sd">            :param log_level:       Minimal log level to set for the new auto_logging</span>
-<span class="sd">            :param max_bytes:       Max size for the log file before rotation starts</span>
-<span class="sd">            :param handlers_list:    A list specifying the handlers (Console, etc..) - Better Leave Empty or None</span>
-<span class="sd">            :param training_log_path: Path to training log file which all modules of super_gradients will write to. Ignored</span>
-<span class="sd">             when set to None.</span>
-<span class="sd">            :param logs_dir_path: Path to sg_logs directory (default=None), where module logs will be saved. When set</span>
-<span class="sd">                to None- module logs will be saved in ~/sg_logs (created if path does not exist). Main use case is for</span>
-<span class="sd">                testing.</span>
-
-
-<span class="sd">            :return: python dict() with the new auto_logging for the module</span>
+<span class="sd">        Setup default logging configuration. Usually happens when app starts, and we don&#39;t have</span>
+<span class="sd">        experiment dir yet.</span>
+<span class="sd">        The default log directory will be `~/sg_logs`</span>
+<span class="sd">        :param log_level: The default log level to use. If None, uses LOG_LEVEL and CONSOLE_LOG_LEVEL environment vars.</span>
+<span class="sd">        :return: None</span>
 <span class="sd">        &quot;&quot;&quot;</span>
 <span class="sd">        &quot;&quot;&quot;</span>
 
 
-        <span class="c1"># LOADING THE ORIGINAL ROOT CONFIG FILE</span>
-        <span class="n">conf_file_name</span> <span class="o">=</span> <span class="s2">&quot;auto_logging_conf.json&quot;</span>
-        <span class="n">conf_file_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span>
-            <span class="n">pkg_resources</span><span class="o">.</span><span class="n">resource_filename</span><span class="p">(</span><span class="s2">&quot;super_gradients&quot;</span><span class="p">,</span> <span class="s2">&quot;/common/auto_logging/&quot;</span><span class="p">),</span> <span class="n">conf_file_name</span>
+        <span class="c1"># There is no _easy_ way to log all events to a single file, when using DDP or DataLoader with num_workers &gt; 1</span>
+        <span class="c1"># on Windows platform. In both these cases a multiple processes will be spawned and multiple logs may be created.</span>
+        <span class="c1"># Therefore the log file will have the parent PID to being able to discriminate the logs corresponding to a single run.</span>
+        <span class="n">timestamp</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">strftime</span><span class="p">(</span><span class="s2">&quot;%Y_%m_</span><span class="si">%d</span><span class="s2">_%H_%M_%S&quot;</span><span class="p">,</span> <span class="n">time</span><span class="o">.</span><span class="n">localtime</span><span class="p">())</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_setup_logging</span><span class="p">(</span>
+            <span class="n">filename</span><span class="o">=</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">expanduser</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;~/sg_logs/sg_logs_</span><span class="si">{</span><span class="n">os</span><span class="o">.</span><span class="n">getppid</span><span class="p">()</span><span class="si">}</span><span class="s2">_</span><span class="si">{</spa
+            <span class="n">copy_already_logged_messages</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
+            <span class="n">filemode</span><span class="o">=</span><span class="s2">&quot;w&quot;</span><span class="p">,</span>
+            <span class="n">log_level</span><span class="o">=</span><span class="n">log_level</span><span class="p">,</span>
+        <span class="p">)</span>
+
+    <span class="k">def</span> <span class="nf">_setup_logging</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">filename</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">copy_already_logged_messages</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span> <span class="n">filemode</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span c
+        <span class="sd">&quot;&quot;&quot;</span>
+<span class="sd">        Sets the logging configuration to store messages to specific file</span>
+<span class="sd">        :param filename: Output log file</span>
+<span class="sd">        :param filemode: Open mode for file</span>
+<span class="sd">        :param copy_already_logged_messages: Controls whether messages from previous log configuration should be copied</span>
+<span class="sd">               to new place. This is helpful to transfer diagnostic messages (from the app start) to experiment dir.</span>
+<span class="sd">        :param log_level: The default log level to use. If None, uses LOG_LEVEL and CONSOLE_LOG_LEVEL environment vars.</span>
+<span class="sd">        :return:</span>
+<span class="sd">        &quot;&quot;&quot;</span>
+        <span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">dirname</span><span class="p">(</span><span class="n">filename</span><span class="p">),</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
+
+        <span class="k">if</span> <span class="n">copy_already_logged_messages</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">filename</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="bp">self</span><span cl
+            <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">filename</span><span class="p">,</span> <span class="s2">&quot;r&quot;</span><span class="p">,</span> <span class="n">encoding</span><span class="o">=</span><span class="s2">&quot;utf-8&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">src</span><span class="p">:</span>
+                <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">filename</span><span class="p">,</span> <span class="s2">&quot;w&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">dst</span><span class="p">:</span>
+                    <span class="n">dst</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="n">src</span><span class="o">.</span><span class="n">read</span><span class="p">())</span>
+
+        <span class="n">file_logging_level</span> <span class="o">=</span> <span class="n">log_level</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">FILE_LOGGING_LEVEL</span>
+        <span class="n">console_logging_level</span> <span class="o">=</span> <span class="n">log_level</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">CONSOLE_LOGGING_LEVEL</span>
+
+        <span class="n">cur_version</span> <span class="o">=</span> <span class="n">sys</span><span class="o">.</span><span class="n">version_info</span>
+        <span class="n">python_38</span> <span class="o">=</span> <span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">8</span><span class="p">)</span>
+        <span class="n">python_39</span> <span class="o">=</span> <span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">9</span><span class="p">)</span>
+        <span class="n">manager</span> <span class="o">=</span> <span class="n">logging</span><span class="o">.</span><span class="n">getLogger</span><span class="p">(</span><span class="s2">&quot;&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">manager</span>
+
+        <span class="n">extra_kwargs</span> <span class="o">=</span> <span class="p">{}</span>
+        <span class="k">if</span> <span class="n">cur_version</span> <span class="o">&gt;=</span> <span class="n">python_38</span><span class="p">:</span>
+            <span class="n">extra_kwargs</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(</span>
+                <span class="n">force</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
+            <span class="p">)</span>
+        <span class="k">else</span><span class="p">:</span>
+            <span class="c1"># If the logging does not support force=True, we should manually delete handlers</span>
+            <span class="k">del</span> <span class="n">manager</span><span class="o">.</span><span class="n">root</span><span class="o">.</span><span class="n">handlers</span><span class="p">[:]</span>
+
+        <span class="k">if</span> <span class="n">cur_version</span> <span class="o">&gt;=</span> <span class="n">python_39</span><span class="p">:</span>
+            <span class="n">extra_kwargs</span><span class="p">[</span><span class="s2">&quot;encoding&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="s2">&quot;utf-8&quot;</span>
+
+        <span class="n">logging</span><span class="o">.</span><span class="n">basicConfig</span><span class="p">(</span>
+            <span class="n">filename</span><span class="o">=</span><span class="n">filename</span><span class="p">,</span>
+            <span class="n">filemode</span><span class="o">=</span><span class="n">filemode</span><span class="p">,</span>
+            <span class="nb">format</span><span class="o">=</span><span class="s2">&quot;</span><span class="si">%(asctime)s</span><span class="s2"> </span><span class="si">%(levelname)s</span><span class="s2"> - </span><span class="si">%(name)s</span><span class="s2"> - </span><span class="si">%(message)s</span><span class="s2">&quot;</span><span class="p">,</span>
+            <span class="n">datefmt</span><span class="o">=</span><span class="s2">&quot;[%Y-%m-</span><span class="si">%d</span><span class="s2"> %H:%M:%S]&quot;</span><span class="p">,</span>
+            <span class="n">level</span><span class="o">=</span><span class="n">file_logging_level</span><span class="p">,</span>
+            <span class="o">**</span><span class="n">extra_kwargs</span><span class="p">,</span>
         <span class="p">)</span>
         <span class="p">)</span>
 
 
-        <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">conf_file_path</span><span class="p">,</span> <span class="s2">&quot;r&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">logging_configuration_file</span><span class="p">:</span>
-            <span class="n">config_dict</span> <span class="o">=</span> <span class="n">json</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">logging_configuration_file</span><span class="p">)</span>
-
-        <span class="c1"># CREATING THE PATH TO THE &quot;HOME&quot; FOLDER WITH THE LOG FILE NAME</span>
-        <span class="k">if</span> <span class="ow">not</span> <span class="n">logs_dir_path</span><span class="p">:</span>
-            <span class="n">log_file_name</span> <span class="o">=</span> <span class="n">module_name</span> <span class="o">+</span> <span class="s2">&quot;.log&quot;</span>
-            <span class="n">user_dir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">expanduser</span><span class="p">(</span><span class="sa">r</span><span class="s2">&quot;~&quot;</span><span class="p">)</span>
-            <span class="n">logs_dir_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">user_dir</span><span class="p">,</span> <span class="s2">&quot;sg_logs&quot;</span><span class="p">)</span>
-
-        <span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">logs_dir_path</span><span class="p">):</span>
-            <span class="k">try</span><span class="p">:</span>
-                <span class="n">os</span><span class="o">.</span><span class="n">mkdir</span><span class="p">(</span><span class="n">logs_dir_path</span><span class="p">)</span>
-            <span class="k">except</span> <span class="ne">Exception</span> <span class="k">as</span> <span class="n">ex</span><span class="p">:</span>
-                <span class="nb">print</span><span class="p">(</span>
-                    <span class="s2">&quot;[WARNING] - sg_logs folder was not found and couldn&#39;t be created from the code - &quot;</span>
-                    <span class="s2">&quot;All of the Log output will be sent to Console!&quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">ex</span><span class="p">)</span>
-                <span class="p">)</span>
-
-            <span class="c1"># HANDLERS LIST IS EMPTY AS CONSOLE IS ONLY ROOT HANDLER BECAUSE MODULE LOGGERS PROPAGATE THEIR LOGS UP.</span>
-            <span class="n">handlers_list</span> <span class="o">=</span> <span class="p">[]</span>
-            <span class="n">logger</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;level&quot;</span><span class="p">:</span> <span class="n">log_level</span><span class="p">,</span> <span class="s2">&quot;handlers&quot;</span><span class="p">:</span> <span class="n">handlers_list</span><span class="p">,</span> <span class="s2">&quot;propagate&quot;</span><span class="p">:</span> <span class="kc">True</span><span class="p">}</span>
-            <span class="n">config_dict</span><span class="p">[</span><span class="s2">&quot;loggers&quot;</span><span class="p">][</span><span class="n">module_name</span><span class="p">]</span> <span class="o">=</span> <span class="n">logger</span>
-
-            <span class="k">return</span> <span class="n">config_dict</span>
-
-        <span class="n">log_file_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">logs_dir_path</span><span class="p">,</span> <span class="n">log_file_name</span><span class="p">)</span>
-
-        <span class="c1"># THE ENTRIES TO ADD TO THE ORIGINAL CONFIGURATION</span>
-        <span class="n">handler_name</span> <span class="o">=</span> <span class="n">module_name</span> <span class="o">+</span> <span class="s2">&quot;_file_handler&quot;</span>
-        <span class="n">file_handler</span> <span class="o">=</span> <span class="p">{</span>
-            <span class="s2">&quot;class&quot;</span><span class="p">:</span> <span class="s2">&quot;logging.handlers.RotatingFileHandler&quot;</span><span class="p">,</span>
-            <span class="s2">&quot;level&quot;</span><span class="p">:</span> <span class="n">log_level</span><span class="p">,</span>
-            <span class="s2">&quot;formatter&quot;</span><span class="p">:</span> <span class="s2">&quot;fileFormatter&quot;</span><span class="p">,</span>
-            <span class="s2">&quot;filename&quot;</span><span class="p">:</span> <span class="n">log_file_path</span><span class="p">,</span>
-            <span class="s2">&quot;maxBytes&quot;</span><span class="p">:</span> <span class="n">max_bytes</span><span class="p">,</span>
-            <span class="s2">&quot;backupCount&quot;</span><span class="p">:</span> <span class="mi">20</span><span class="p">,</span>
-            <span class="s2">&quot;encoding&quot;</span><span class="p">:</span> <span class="s2">&quot;utf8&quot;</span><span class="p">,</span>
-        <span class="p">}</span>
-
-        <span class="c1"># CREATING ONLY A FILE HANDLER, CONSOLE IS ONLY ROOT HANDLER AS MODULE LOGGERS PROPAGATE THEIR LOGS UP.</span>
-        <span class="k">if</span> <span class="n">handlers_list</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="n">handlers_list</span><span class="o">.</span><span class="n">empty</span><span class="p">():</span>
-            <span class="n">handlers_list</span> <span class="o">=</span> <span class="p">[</span><span class="n">handler_name</span><span class="p">]</span>
-
-        <span class="n">logger</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;level&quot;</span><span class="p">:</span> <span class="n">log_level</span><span class="p">,</span> <span class="s2">&quot;handlers&quot;</span><span class="p">:</span> <span class="n">handlers_list</span><span class="p">,</span> <span class="s2">&quot;propagate&quot;</span><span class="p">:</span> <span class="kc">True</span><span class="p">}</span>
-
-        <span class="c1"># ADDING THE NEW LOGGER ENTRIES TO THE CONFIG DICT</span>
-        <span class="n">config_dict</span><span class="p">[</span><span class="s2">&quot;handlers&quot;</span><span class="p">][</span><span class="n">handler_name</span><span class="p">]</span> <span class="o">=</span> <span class="n">file_handler</span>
-        <span class="n">config_dict</span><span class="p">[</span><span class="s2">&quot;loggers&quot;</span><span class="p">][</span><span class="n">module_name</span><span class="p">]</span> <span class="o">=</span> <span class="n">logger</span>
-        <span class="n">config_dict</span><span class="p">[</span><span class="s2">&quot;root&quot;</span><span class="p">][</span><span class="s2">&quot;handlers&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">handler_name</span><span class="p">)</span>
-
-        <span class="k">if</span> <span class="n">training_log_path</span><span class="p">:</span>
-            <span class="n">training_file_handler</span> <span class="o">=</span> <span class="p">{</span>
-                <span class="s2">&quot;class&quot;</span><span class="p">:</span> <span class="s2">&quot;logging.handlers.RotatingFileHandler&quot;</span><span class="p">,</span>
-                <span class="s2">&quot;level&quot;</span><span class="p">:</span> <span class="n">log_level</span><span class="p">,</span>
-                <span class="s2">&quot;formatter&quot;</span><span class="p">:</span> <span class="s2">&quot;fileFormatter&quot;</span><span class="p">,</span>
-                <span class="s2">&quot;filename&quot;</span><span class="p">:</span> <span class="n">training_log_path</span><span class="p">,</span>
-                <span class="s2">&quot;maxBytes&quot;</span><span class="p">:</span> <span class="n">max_bytes</span><span class="p">,</span>
-                <span class="s2">&quot;backupCount&quot;</span><span class="p">:</span> <span class="mi">20</span><span class="p">,</span>
-                <span class="s2">&quot;encoding&quot;</span><span class="p">:</span> <span class="s2">&quot;utf8&quot;</span><span class="p">,</span>
-            <span class="p">}</span>
-
-            <span class="c1"># ALL OF DECI_TRAINER MODULES LOGGERS PROPAGATE UP TO THE ROOT SO THE ADD TRAIN FILE HANDLER FOR THE ROOT.</span>
-            <span class="n">config_dict</span><span class="p">[</span><span class="s2">&quot;handlers&quot;</span><span class="p">][</span><span class="s2">&quot;training&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">training_file_handler</span>
-            <span class="n">config_dict</span><span class="p">[</span><span class="s2">&quot;root&quot;</span><span class="p">][</span><span class="s2">&quot;handlers&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="s2">&quot;training&quot;</span><span class="p">)</span>
-
-        <span class="k">return</span> <span class="n">config_dict</span></div></div>
+        <span class="c1"># Add console handler</span>
+        <span class="n">console_handler</span> <span class="o">=</span> <span class="n">logging</span><span class="o">.</span><span class="n">StreamHandler</span><span class="p">()</span>
+        <span class="n">console_handler</span><span class="o">.</span><span class="n">setLevel</span><span class="p">(</span><span class="n">console_logging_level</span><span class="p">)</span>
+        <span class="n">console_handler</span><span class="o">.</span><span class="n">setFormatter</span><span class="p">(</span>
+            <span class="n">logging</span><span class="o">.</span><span class="n">Formatter</span><span class="p">(</span>
+                <span class="s2">&quot;</span><span class="si">%(asctime)s</span><span class="s2"> </span><span class="si">%(levelname)s</span><span class="s2"> - </span><span class="si">%(filename)s</span><span class="s2"> - </span><span class="si">%(message)s</span><span class="s2">&quot;</span><span class="p">,</span>
+                <span class="n">datefmt</span><span class="o">=</span><span class="s2">&quot;[%Y-%m-</span><span class="si">%d</span><span class="s2"> %H:%M:%S]&quot;</span><span class="p">,</span>
+            <span class="p">)</span>
+        <span class="p">)</span>
+        <span class="n">manager</span><span class="o">.</span><span class="n">root</span><span class="o">.</span><span class="n">handlers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">console_handler</span><span class="p">)</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">filename</span> <span class="o">=</span> <span class="n">filename</span>
+
+<div class="viewcode-block" id="AutoLoggerConfig.get_instance"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.AutoLoggerConfig.get_instance">[docs]</a>    <span class="nd">@classmethod</span>
+    <span class="k">def</span> <span class="nf">get_instance</span><span class="p">(</span><span class="bp">cls</span><span class="p">):</span>
+        <span class="k">global</span> <span class="n">_super_gradients_logger_config</span>
+        <span class="k">if</span> <span class="n">_super_gradients_logger_config</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">_super_gradients_logger_config</span> <span class="o">=</span> <span class="bp">cls</span><span class="p">()</span>
+            <span class="n">_super_gradients_logger_config</span><span class="o">.</span><span class="n">_setup_default_logging</span><span class="p">()</span>
+
+        <span class="k">return</span> <span class="n">_super_gradients_logger_config</span></div>
+
+<div class="viewcode-block" id="AutoLoggerConfig.get_log_file_path"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.AutoLoggerConfig.get_log_file_path">[docs]</a>    <span class="nd">@classmethod</span>
+    <span class="k">def</span> <span class="nf">get_log_file_path</span><span class="p">(</span><span class="bp">cls</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
+        <span class="sd">&quot;&quot;&quot;</span>
+<span class="sd">        Return the current log file used to store log messages</span>
+<span class="sd">        :return: Full path to log file</span>
+<span class="sd">        &quot;&quot;&quot;</span>
+        <span class="bp">self</span> <span class="o">=</span> <span class="bp">cls</span><span class="o">.</span><span class="n">get_instance</span><span class="p">()</span>
+        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">filename</span></div>
+
+<div class="viewcode-block" id="AutoLoggerConfig.setup_logging"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.AutoLoggerConfig.setup_logging">[docs]</a>    <span class="nd">@classmethod</span>
+    <span class="k">def</span> <span class="nf">setup_logging</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">filename</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">copy_already_logged_messages</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span> <span class="n">filemode</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span cla
+        <span class="bp">self</span> <span class="o">=</span> <span class="bp">cls</span><span class="o">.</span><span class="n">get_instance</span><span class="p">()</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_setup_logging</span><span class="p">(</span><span class="n">filename</span><span class="p">,</span> <span class="n">copy_already_logged_messages</span><span class="p">,</span> <span class="n">filemode</span><span class="p">,</span> <span class="n">log_level</span><span class="p">)</span></div></div>
+
+
+<span class="n">_super_gradients_logger_config</span> <span class="o">=</span> <span class="kc">None</span>
 </pre></div>
 </pre></div>
 
 
            </div>
            </div>
@@ -225,4 +241,4 @@
   </script> 
   </script> 
 
 
 </body>
 </body>
-</html>
+</html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.common.auto_logging.console_logging &mdash; SuperGradients 3.0.3 documentation</title>
  7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
  10. <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
  11. <!--[if lt IE 9]>
  12. <script src="../../../../_static/js/html5shiv.min.js"></script>
  13. <![endif]-->
  14. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
  15. <script src="../../../../_static/jquery.js"></script>
  16. <script src="../../../../_static/underscore.js"></script>
  17. <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
  18. <script src="../../../../_static/doctools.js"></script>
  19. <script src="../../../../_static/sphinx_highlight.js"></script>
  20. <script src="../../../../_static/js/theme.js"></script>
  21. <link rel="index" title="Index" href="../../../../genindex.html" />
  22. <link rel="search" title="Search" href="../../../../search.html" />
  23. </head>
  24. <body class="wy-body-for-nav">
  25. <div class="wy-grid-for-nav">
  26. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  27. <div class="wy-side-scroll">
  28. <div class="wy-side-nav-search" >
  29. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
  30. </a>
  31. <div role="search">
  32. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
  33. <input type="text" name="q" placeholder="Search docs" />
  34. <input type="hidden" name="check_keywords" value="yes" />
  35. <input type="hidden" name="area" value="default" />
  36. </form>
  37. </div>
  38. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  39. <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
  40. <ul>
  41. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
  42. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
  45. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
  46. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
  47. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
  57. </ul>
  58. <p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
  59. <ul>
  60. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
  61. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
  62. </ul>
  63. </div>
  64. </div>
  65. </nav>
  66. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  67. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  68. <a href="../../../../index.html">SuperGradients</a>
  69. </nav>
  70. <div class="wy-nav-content">
  71. <div class="rst-content">
  72. <div role="navigation" aria-label="Page navigation">
  73. <ul class="wy-breadcrumbs">
  74. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  75. <li><a href="../../../index.html">Module code</a> &raquo;</li>
  76. <li>super_gradients.common.auto_logging.console_logging</li>
  77. <li class="wy-breadcrumbs-aside">
  78. </li>
  79. </ul>
  80. <hr/>
  81. </div>
  82. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  83. <div itemprop="articleBody">
  84. <h1>Source code for super_gradients.common.auto_logging.console_logging</h1><div class="highlight"><pre>
  85. <span></span><span class="kn">import</span> <span class="nn">os</span>
  86. <span class="kn">import</span> <span class="nn">sys</span>
  87. <span class="kn">from</span> <span class="nn">datetime</span> <span class="kn">import</span> <span class="n">datetime</span>
  88. <span class="kn">from</span> <span class="nn">pathlib</span> <span class="kn">import</span> <span class="n">Path</span>
  89. <span class="kn">from</span> <span class="nn">io</span> <span class="kn">import</span> <span class="n">StringIO</span>
  90. <span class="kn">import</span> <span class="nn">atexit</span>
  91. <span class="kn">from</span> <span class="nn">threading</span> <span class="kn">import</span> <span class="n">Lock</span>
  92. <span class="kn">from</span> <span class="nn">super_gradients.common.environment.env_helpers</span> <span class="kn">import</span> <span class="n">multi_process_safe</span><span class="p">,</span> <span class="n">is_main_process</span>
  93. <span class="k">class</span> <span class="nc">BufferWriter</span><span class="p">:</span>
  94. <span class="sd">&quot;&quot;&quot;File writer buffer that opens a file only when flushing and under the condition that threshold buffersize was reached.&quot;&quot;&quot;</span>
  95. <span class="n">FILE_BUFFER_SIZE</span> <span class="o">=</span> <span class="mi">10_000</span> <span class="c1"># Number of chars to be buffered before writing the buffer on disk.</span>
  96. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">filename</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">buffer</span><span class="p">:</span> <span class="n">StringIO</span><span class="p">,</span> <span class="n">buffer_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">lock</span><span class="p">:</span> <span class="n">Lock</span><span class="p">):</span>
  97. <span class="sd">&quot;&quot;&quot;</span>
  98. <span class="sd"> :param filename: Name of the file where to write the bugger</span>
  99. <span class="sd"> :param buffer: Buffer object</span>
  100. <span class="sd"> :param buffer_size: Number of chars to be buffered before writing the buffer on disk.</span>
  101. <span class="sd"> :param lock: Thread lock to prevent multiple threads to write at the same time</span>
  102. <span class="sd"> &quot;&quot;&quot;</span>
  103. <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span> <span class="o">=</span> <span class="n">buffer</span>
  104. <span class="bp">self</span><span class="o">.</span><span class="n">filename</span> <span class="o">=</span> <span class="n">filename</span>
  105. <span class="bp">self</span><span class="o">.</span><span class="n">buffer_size</span> <span class="o">=</span> <span class="n">buffer_size</span>
  106. <span class="bp">self</span><span class="o">.</span><span class="n">lock</span> <span class="o">=</span> <span class="n">lock</span>
  107. <span class="k">def</span> <span class="nf">write</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
  108. <span class="sd">&quot;&quot;&quot;Write to buffer (not on disk).&quot;&quot;&quot;</span>
  109. <span class="k">with</span> <span class="bp">self</span><span class="o">.</span><span class="n">lock</span><span class="p">:</span>
  110. <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
  111. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_require_flush</span><span class="p">():</span>
  112. <span class="bp">self</span><span class="o">.</span><span class="n">flush</span><span class="p">()</span>
  113. <span class="k">def</span> <span class="nf">flush</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">force</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
  114. <span class="sd">&quot;&quot;&quot;Write the buffer on disk if relevant.&quot;&quot;&quot;</span>
  115. <span class="k">if</span> <span class="n">force</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">_require_flush</span><span class="p">():</span>
  116. <span class="k">with</span> <span class="bp">self</span><span class="o">.</span><span class="n">lock</span><span class="p">:</span>
  117. <span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">dirname</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">filename</span><span class="p">),</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  118. <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">filename</span><span class="p">,</span> <span class="s2">&quot;a&quot;</span><span class="p">,</span> <span class="n">encoding</span><span class="o">=</span><span class="s2">&quot;utf-8&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
  119. <span class="n">f</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="o">.</span><span class="n">getvalue</span><span class="p">())</span>
  120. <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="o">.</span><span class="n">truncate</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
  121. <span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="o">.</span><span class="n">seek</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
  122. <span class="k">def</span> <span class="nf">_require_flush</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
  123. <span class="sd">&quot;&quot;&quot;Indicate if a buffer is needed (i.e. if buffer size above threshold)&quot;&quot;&quot;</span>
  124. <span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">buffer</span><span class="o">.</span><span class="n">getvalue</span><span class="p">())</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">buffer_size</span>
  125. <span class="k">class</span> <span class="nc">StderrTee</span><span class="p">(</span><span class="n">BufferWriter</span><span class="p">):</span>
  126. <span class="sd">&quot;&quot;&quot;Duplicate the stderr stream to save it into a given file.&quot;&quot;&quot;</span>
  127. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">filename</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">buffer</span><span class="p">:</span> <span class="n">StringIO</span><span class="p">,</span> <span class="n">buffer_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">lock</span><span class="p">:</span> <span class="n">Lock</span><span class="p">):</span>
  128. <span class="sd">&quot;&quot;&quot;</span>
  129. <span class="sd"> :param filename: Name of the file where to write the bugger</span>
  130. <span class="sd"> :param buffer: Buffer object</span>
  131. <span class="sd"> :param buffer_size: Number of chars to be buffered before writing the buffer on disk.</span>
  132. <span class="sd"> :param lock: Thread lock to prevent multiple threads to write at the same time</span>
  133. <span class="sd"> &quot;&quot;&quot;</span>
  134. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">filename</span><span class="p">,</span> <span class="n">buffer</span><span class="p">,</span> <span class="n">buffer_size</span><span class="p">,</span> <span class="n">lock</span><span class="p">)</span>
  135. <span class="bp">self</span><span class="o">.</span><span class="n">stderr</span> <span class="o">=</span> <span class="n">sys</span><span class="o">.</span><span class="n">stderr</span>
  136. <span class="n">sys</span><span class="o">.</span><span class="n">stderr</span> <span class="o">=</span> <span class="bp">self</span>
  137. <span class="k">def</span> <span class="fm">__del__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  138. <span class="n">sys</span><span class="o">.</span><span class="n">stderr</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">stderr</span>
  139. <span class="k">def</span> <span class="nf">write</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span>
  140. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
  141. <span class="bp">self</span><span class="o">.</span><span class="n">stderr</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
  142. <span class="k">def</span> <span class="fm">__getattr__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">attr</span><span class="p">):</span>
  143. <span class="k">return</span> <span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">stderr</span><span class="p">,</span> <span class="n">attr</span><span class="p">)</span>
  144. <span class="k">class</span> <span class="nc">StdoutTee</span><span class="p">(</span><span class="n">BufferWriter</span><span class="p">):</span>
  145. <span class="sd">&quot;&quot;&quot;Duplicate the stdout stream to save it into a given file.&quot;&quot;&quot;</span>
  146. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">filename</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">buffer</span><span class="p">,</span> <span class="n">buffer_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">lock</span><span class="p">:</span> <span class="n">Lock</span><span class="p">):</span>
  147. <span class="sd">&quot;&quot;&quot;</span>
  148. <span class="sd"> :param filename: Name of the file where to write the bugger</span>
  149. <span class="sd"> :param buffer: Buffer object</span>
  150. <span class="sd"> :param buffer_size: Number of chars to be buffered before writing the buffer on disk.</span>
  151. <span class="sd"> :param lock: Thread lock to prevent multiple threads to write at the same time</span>
  152. <span class="sd"> &quot;&quot;&quot;</span>
  153. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">filename</span><span class="p">,</span> <span class="n">buffer</span><span class="p">,</span> <span class="n">buffer_size</span><span class="p">,</span> <span class="n">lock</span><span class="p">)</span>
  154. <span class="bp">self</span><span class="o">.</span><span class="n">stdout</span> <span class="o">=</span> <span class="n">sys</span><span class="o">.</span><span class="n">stdout</span>
  155. <span class="n">sys</span><span class="o">.</span><span class="n">stdout</span> <span class="o">=</span> <span class="bp">self</span>
  156. <span class="k">def</span> <span class="fm">__del__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  157. <span class="n">sys</span><span class="o">.</span><span class="n">stdout</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">stdout</span>
  158. <span class="k">def</span> <span class="nf">write</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span>
  159. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
  160. <span class="bp">self</span><span class="o">.</span><span class="n">stdout</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
  161. <span class="k">def</span> <span class="fm">__getattr__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">attr</span><span class="p">):</span>
  162. <span class="k">return</span> <span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">stdout</span><span class="p">,</span> <span class="n">attr</span><span class="p">)</span>
  163. <span class="k">def</span> <span class="nf">copy_file</span><span class="p">(</span><span class="n">src_filename</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">dest_filename</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">copy_mode</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;w&quot;</span><span class="p">):</span>
  164. <span class="sd">&quot;&quot;&quot;Copy a file from source to destination. Also works when the destination folder does not exist.&quot;&quot;&quot;</span>
  165. <span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">dirname</span><span class="p">(</span><span class="n">dest_filename</span><span class="p">),</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  166. <span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">src_filename</span><span class="p">):</span>
  167. <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">src_filename</span><span class="p">,</span> <span class="s2">&quot;r&quot;</span><span class="p">,</span> <span class="n">encoding</span><span class="o">=</span><span class="s2">&quot;utf-8&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">src</span><span class="p">:</span>
  168. <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">dest_filename</span><span class="p">,</span> <span class="n">copy_mode</span><span class="p">,</span> <span class="n">encoding</span><span class="o">=</span><span class="s2">&quot;utf-8&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">dst</span><span class="p">:</span>
  169. <span class="n">dst</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="n">src</span><span class="o">.</span><span class="n">read</span><span class="p">())</span>
  170. <div class="viewcode-block" id="ConsoleSink"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.auto_logging.ConsoleSink">[docs]</a><span class="k">class</span> <span class="nc">ConsoleSink</span><span class="p">:</span>
  171. <span class="sd">&quot;&quot;&quot;Singleton responsible to sink the console streams (stdout/stderr) into a file.&quot;&quot;&quot;</span>
  172. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  173. <span class="bp">self</span><span class="o">.</span><span class="n">_setup</span><span class="p">()</span>
  174. <span class="n">atexit</span><span class="o">.</span><span class="n">register</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_flush</span><span class="p">)</span> <span class="c1"># Flush at the end of the process</span>
  175. <span class="nd">@multi_process_safe</span>
  176. <span class="k">def</span> <span class="nf">_setup</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  177. <span class="sd">&quot;&quot;&quot;On instantiation, setup the default sink file.&quot;&quot;&quot;</span>
  178. <span class="n">filename</span> <span class="o">=</span> <span class="n">Path</span><span class="o">.</span><span class="n">home</span><span class="p">()</span> <span class="o">/</span> <span class="s2">&quot;sg_logs&quot;</span> <span class="o">/</span> <span class="s2">&quot;console.log&quot;</span>
  179. <span class="n">filename</span><span class="o">.</span><span class="n">parent</span><span class="o">.</span><span class="n">mkdir</span><span class="p">(</span><span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  180. <span class="bp">self</span><span class="o">.</span><span class="n">filename</span> <span class="o">=</span> <span class="nb">str</span><span class="p">(</span><span class="n">filename</span><span class="p">)</span>
  181. <span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">dirname</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">filename</span><span class="p">),</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  182. <span class="n">buffer</span> <span class="o">=</span> <span class="n">StringIO</span><span class="p">()</span>
  183. <span class="n">lock</span> <span class="o">=</span> <span class="n">Lock</span><span class="p">()</span>
  184. <span class="bp">self</span><span class="o">.</span><span class="n">stdout</span> <span class="o">=</span> <span class="n">StdoutTee</span><span class="p">(</span><span class="n">filename</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">filename</span><span class="p">,</span> <span class="n">buffer</span><span class="o">=</span><span class="n">buffer</span><span class="p">,</span> <span class="n">buffer_size</span><span class="o">=</span><span class="n">BufferWriter</span><span class="o">.</span><span class="n">FILE_BUFFER_SIZE</span><span class="p">,</span> <span class="n">lock</span><span class="o">=</span><span class="n">lock</span><span class="p">)</span>
  185. <span class="bp">self</span><span class="o">.</span><span class="n">stderr</span> <span class="o">=</span> <span class="n">StderrTee</span><span class="p">(</span><span class="n">filename</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">filename</span><span class="p">,</span> <span class="n">buffer</span><span class="o">=</span><span class="n">buffer</span><span class="p">,</span> <span class="n">buffer_size</span><span class="o">=</span><span class="n">BufferWriter</span><span class="o">.</span><span class="n">FILE_BUFFER_SIZE</span><span class="p">,</span> <span class="n">lock</span><span class="o">=</span><span class="n">lock</span><span class="p">)</span>
  186. <span class="c1"># We don&#39;t want to rewrite this for subprocesses when using DDP.</span>
  187. <span class="k">if</span> <span class="n">is_main_process</span><span class="p">():</span>
  188. <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">filename</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s2">&quot;w&quot;</span><span class="p">,</span> <span class="n">encoding</span><span class="o">=</span><span class="s2">&quot;utf-8&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
  189. <span class="n">f</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="s2">&quot;============================================================</span><span class="se">\n</span><span class="s2">&quot;</span><span class="p">)</span>
  190. <span class="n">f</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;New run started at </span><span class="si">{</span><span class="n">datetime</span><span class="o">.</span><span class="n">now</span><span class="p">()</span><span class="o">.</span><span class="n">strftime</span><span class="p">(</span><span class="s2">&quot;%Y-%m-</span><span class="si">%d</span><span class="s2">.%H:%M:%S.</span><span class="si">%f</span><span class="s2">&quot;</span><span class="p">)</span><span class="si">}</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">)</span>
  191. <span class="n">f</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;sys.argv: &quot;</span><span class="si">{</span><span class="s2">&quot; &quot;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">sys</span><span class="o">.</span><span class="n">argv</span><span class="p">)</span><span class="si">}</span><span class="s1">&quot;</span><span class="se">\n</span><span class="s1">&#39;</span><span class="p">)</span>
  192. <span class="n">f</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="s2">&quot;============================================================</span><span class="se">\n</span><span class="s2">&quot;</span><span class="p">)</span>
  193. <span class="bp">self</span><span class="o">.</span><span class="n">stdout</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;The console stream is logged into </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">filename</span><span class="si">}</span><span class="se">\n</span><span class="s2">&quot;</span><span class="p">)</span>
  194. <span class="nd">@multi_process_safe</span>
  195. <span class="k">def</span> <span class="nf">_set_location</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">filename</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
  196. <span class="sd">&quot;&quot;&quot;Copy and redirect the sink file into another location.&quot;&quot;&quot;</span>
  197. <span class="bp">self</span><span class="o">.</span><span class="n">_flush</span><span class="p">()</span>
  198. <span class="n">prev_filename</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">filename</span>
  199. <span class="n">copy_file</span><span class="p">(</span><span class="n">src_filename</span><span class="o">=</span><span class="n">prev_filename</span><span class="p">,</span> <span class="n">dest_filename</span><span class="o">=</span><span class="n">filename</span><span class="p">,</span> <span class="n">copy_mode</span><span class="o">=</span><span class="s2">&quot;a&quot;</span><span class="p">)</span>
  200. <span class="bp">self</span><span class="o">.</span><span class="n">filename</span> <span class="o">=</span> <span class="n">filename</span>
  201. <span class="bp">self</span><span class="o">.</span><span class="n">stdout</span><span class="o">.</span><span class="n">filename</span> <span class="o">=</span> <span class="n">filename</span>
  202. <span class="bp">self</span><span class="o">.</span><span class="n">stderr</span><span class="o">.</span><span class="n">filename</span> <span class="o">=</span> <span class="n">filename</span>
  203. <span class="bp">self</span><span class="o">.</span><span class="n">stdout</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;The console stream is now moved to </span><span class="si">{</span><span class="n">filename</span><span class="si">}</span><span class="se">\n</span><span class="s2">&quot;</span><span class="p">)</span>
  204. <div class="viewcode-block" id="ConsoleSink.set_location"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.auto_logging.ConsoleSink.set_location">[docs]</a> <span class="nd">@staticmethod</span>
  205. <span class="k">def</span> <span class="nf">set_location</span><span class="p">(</span><span class="n">filename</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
  206. <span class="sd">&quot;&quot;&quot;Copy and redirect the sink file into another location.&quot;&quot;&quot;</span>
  207. <span class="n">_console_sink</span><span class="o">.</span><span class="n">_set_location</span><span class="p">(</span><span class="n">filename</span><span class="p">)</span></div>
  208. <span class="nd">@multi_process_safe</span>
  209. <span class="k">def</span> <span class="nf">_flush</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  210. <span class="sd">&quot;&quot;&quot;Force the flush on stdout and stderr.&quot;&quot;&quot;</span>
  211. <span class="bp">self</span><span class="o">.</span><span class="n">stdout</span><span class="o">.</span><span class="n">flush</span><span class="p">(</span><span class="n">force</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  212. <span class="bp">self</span><span class="o">.</span><span class="n">stderr</span><span class="o">.</span><span class="n">flush</span><span class="p">(</span><span class="n">force</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  213. <div class="viewcode-block" id="ConsoleSink.flush"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.auto_logging.ConsoleSink.flush">[docs]</a> <span class="nd">@staticmethod</span>
  214. <span class="k">def</span> <span class="nf">flush</span><span class="p">():</span>
  215. <span class="sd">&quot;&quot;&quot;Force the flush on stdout and stderr.&quot;&quot;&quot;</span>
  216. <span class="n">_console_sink</span><span class="o">.</span><span class="n">_flush</span><span class="p">()</span></div>
  217. <div class="viewcode-block" id="ConsoleSink.get_filename"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.auto_logging.ConsoleSink.get_filename">[docs]</a> <span class="nd">@staticmethod</span>
  218. <span class="k">def</span> <span class="nf">get_filename</span><span class="p">():</span>
  219. <span class="sd">&quot;&quot;&quot;Get the filename of the sink.&quot;&quot;&quot;</span>
  220. <span class="k">return</span> <span class="n">_console_sink</span><span class="o">.</span><span class="n">filename</span></div></div>
  221. <span class="n">_console_sink</span> <span class="o">=</span> <span class="n">ConsoleSink</span><span class="p">()</span>
  222. </pre></div>
  223. </div>
  224. </div>
  225. <footer>
  226. <hr/>
  227. <div role="contentinfo">
  228. <p>&#169; Copyright 2021, SuperGradients team.</p>
  229. </div>
  230. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  231. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  232. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  233. </footer>
  234. </div>
  235. </div>
  236. </section>
  237. </div>
  238. <script>
  239. jQuery(function () {
  240. SphinxRtdTheme.Navigation.enable(true);
  241. });
  242. </script>
  243. </body>
  244. </html>
Discard
Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
@@ -3,10 +3,11 @@
 <head>
 <head>
   <meta charset="utf-8" />
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-  <title>super_gradients.common.aws_connection.aws_connector &mdash; SuperGradients 1.0 documentation</title>
+  <title>super_gradients.common.aws_connection.aws_connector &mdash; SuperGradients 3.0.3 documentation</title>
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
+      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
   <!--[if lt IE 9]>
   <!--[if lt IE 9]>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
   <![endif]-->
   <![endif]-->
@@ -14,7 +15,9 @@
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/underscore.js"></script>
         <script src="../../../../_static/underscore.js"></script>
+        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
         <script src="../../../../_static/doctools.js"></script>
         <script src="../../../../_static/doctools.js"></script>
+        <script src="../../../../_static/sphinx_highlight.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="search" title="Search" href="../../../../search.html" /> 
     <link rel="search" title="Search" href="../../../../search.html" /> 
@@ -35,30 +38,29 @@
   </form>
   </form>
 </div>
 </div>
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
-              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
+              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
 <ul>
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
 </ul>
 </ul>
-<p class="caption"><span class="caption-text">Technical Documentation</span></p>
+<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
 <ul>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
-</ul>
-<p class="caption"><span class="caption-text">User Guide</span></p>
-<ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
 </ul>
 </ul>
 
 
         </div>
         </div>
@@ -93,7 +95,7 @@
 <span class="kn">from</span> <span class="nn">botocore.exceptions</span> <span class="kn">import</span> <span class="n">ClientError</span><span class="p">,</span> <span class="n">ProfileNotFound</span>
 <span class="kn">from</span> <span class="nn">botocore.exceptions</span> <span class="kn">import</span> <span class="n">ClientError</span><span class="p">,</span> <span class="n">ProfileNotFound</span>
 
 
 
 
-<div class="viewcode-block" id="AWSConnector"><a class="viewcode-back" href="../../../../super_gradients.common.aws_connection.html#super_gradients.common.AWSConnector">[docs]</a><span class="k">class</span> <span class="nc">AWSConnector</span><span class="p">:</span>
+<div class="viewcode-block" id="AWSConnector"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.AWSConnector">[docs]</a><span class="k">class</span> <span class="nc">AWSConnector</span><span class="p">:</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    AWSConnector - Connects to AWS using Credentials File or IAM Role</span>
 <span class="sd">    AWSConnector - Connects to AWS using Credentials File or IAM Role</span>
 <span class="sd">    &quot;&quot;&quot;</span>
 <span class="sd">    &quot;&quot;&quot;</span>
@@ -133,7 +135,7 @@
                     <span class="n">ex</span><span class="p">))</span>
                     <span class="n">ex</span><span class="p">))</span>
             <span class="k">return</span> <span class="kc">None</span>
             <span class="k">return</span> <span class="kc">None</span>
 
 
-<div class="viewcode-block" id="AWSConnector.get_aws_session"><a class="viewcode-back" href="../../../../super_gradients.common.aws_connection.html#super_gradients.common.AWSConnector.get_aws_session">[docs]</a>    <span class="nd">@staticmethod</span>
+<div class="viewcode-block" id="AWSConnector.get_aws_session"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.AWSConnector.get_aws_session">[docs]</a>    <span class="nd">@staticmethod</span>
     <span class="k">def</span> <span class="nf">get_aws_session</span><span class="p">(</span><span class="n">profile_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">boto3</span><span class="o">.</span><span class="n">Session</span><span class="p">:</span>
     <span class="k">def</span> <span class="nf">get_aws_session</span><span class="p">(</span><span class="n">profile_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">boto3</span><span class="o">.</span><span class="n">Session</span><span class="p">:</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">        get_aws_session - Connects to AWS to retrieve an AWS Session</span>
 <span class="sd">        get_aws_session - Connects to AWS to retrieve an AWS Session</span>
@@ -149,7 +151,7 @@
 
 
         <span class="k">return</span> <span class="n">aws_session</span></div>
         <span class="k">return</span> <span class="n">aws_session</span></div>
 
 
-<div class="viewcode-block" id="AWSConnector.get_aws_client_for_service_name"><a class="viewcode-back" href="../../../../super_gradients.common.aws_connection.html#super_gradients.common.AWSConnector.get_aws_client_for_service_name">[docs]</a>    <span class="nd">@staticmethod</span>
+<div class="viewcode-block" id="AWSConnector.get_aws_client_for_service_name"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.AWSConnector.get_aws_client_for_service_name">[docs]</a>    <span class="nd">@staticmethod</span>
     <span class="k">def</span> <span class="nf">get_aws_client_for_service_name</span><span class="p">(</span><span class="n">profile_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">service_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">boto3</span><span class="o">.</span><span class="n">Session</span><span class="o">.</span><span class="n">client</span><
     <span class="k">def</span> <span class="nf">get_aws_client_for_service_name</span><span class="p">(</span><span class="n">profile_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">service_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">boto3</span><span class="o">.</span><span class="n">Session</span><span class="o">.</span><span class="n">client</span><
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">        get_aws_client_for_service_name - Connects to AWS to retrieve the relevant Client</span>
 <span class="sd">        get_aws_client_for_service_name - Connects to AWS to retrieve the relevant Client</span>
@@ -166,7 +168,7 @@
 
 
         <span class="k">return</span> <span class="n">aws_session</span><span class="o">.</span><span class="n">client</span><span class="p">(</span><span class="n">service_name</span><span class="o">=</span><span class="n">service_name</span><span class="p">)</span></div>
         <span class="k">return</span> <span class="n">aws_session</span><span class="o">.</span><span class="n">client</span><span class="p">(</span><span class="n">service_name</span><span class="o">=</span><span class="n">service_name</span><span class="p">)</span></div>
 
 
-<div class="viewcode-block" id="AWSConnector.get_aws_resource_for_service_name"><a class="viewcode-back" href="../../../../super_gradients.common.aws_connection.html#super_gradients.common.AWSConnector.get_aws_resource_for_service_name">[docs]</a>    <span class="nd">@staticmethod</span>
+<div class="viewcode-block" id="AWSConnector.get_aws_resource_for_service_name"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.AWSConnector.get_aws_resource_for_service_name">[docs]</a>    <span class="nd">@staticmethod</span>
     <span class="k">def</span> <span class="nf">get_aws_resource_for_service_name</span><span class="p">(</span><span class="n">profile_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">service_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">boto3</span><span class="o">.</span><span class="n">Session</span><span class="o">.</span><span class="n">resource</sp
     <span class="k">def</span> <span class="nf">get_aws_resource_for_service_name</span><span class="p">(</span><span class="n">profile_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">service_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">boto3</span><span class="o">.</span><span class="n">Session</span><span class="o">.</span><span class="n">resource</sp
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">        Connects to AWS to retrieve the relevant Resource (More functionality then Client)</span>
 <span class="sd">        Connects to AWS to retrieve the relevant Resource (More functionality then Client)</span>
@@ -183,7 +185,7 @@
 
 
         <span class="k">return</span> <span class="n">aws_session</span><span class="o">.</span><span class="n">resource</span><span class="p">(</span><span class="n">service_name</span><span class="o">=</span><span class="n">service_name</span><span class="p">)</span></div>
         <span class="k">return</span> <span class="n">aws_session</span><span class="o">.</span><span class="n">resource</span><span class="p">(</span><span class="n">service_name</span><span class="o">=</span><span class="n">service_name</span><span class="p">)</span></div>
 
 
-<div class="viewcode-block" id="AWSConnector.is_client_error"><a class="viewcode-back" href="../../../../super_gradients.common.aws_connection.html#super_gradients.common.AWSConnector.is_client_error">[docs]</a>    <span class="nd">@staticmethod</span>
+<div class="viewcode-block" id="AWSConnector.is_client_error"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.AWSConnector.is_client_error">[docs]</a>    <span class="nd">@staticmethod</span>
     <span class="k">def</span> <span class="nf">is_client_error</span><span class="p">(</span><span class="n">code</span><span class="p">):</span>
     <span class="k">def</span> <span class="nf">is_client_error</span><span class="p">(</span><span class="n">code</span><span class="p">):</span>
         <span class="n">e</span> <span class="o">=</span> <span class="n">sys</span><span class="o">.</span><span class="n">exc_info</span><span class="p">()[</span><span class="mi">1</span><span class="p">]</span>
         <span class="n">e</span> <span class="o">=</span> <span class="n">sys</span><span class="o">.</span><span class="n">exc_info</span><span class="p">()[</span><span class="mi">1</span><span class="p">]</span>
         <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">e</span><span class="p">,</span> <span class="n">ClientError</span><span class="p">)</span> <span class="ow">and</span> <span class="n">e</span><span class="o">.</span><span class="n">response</span><span class="p">[</span><span class="s2">&quot;Error&quot;</span><span class="p">][</span><span class="s2">&quot;Code&quot;</span><span class="p">]</span> <span class="o">==</span> <span class
         <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">e</span><span class="p">,</span> <span class="n">ClientError</span><span class="p">)</span> <span class="ow">and</span> <span class="n">e</span><span class="o">.</span><span class="n">response</span><span class="p">[</span><span class="s2">&quot;Error&quot;</span><span class="p">][</span><span class="s2">&quot;Code&quot;</span><span class="p">]</span> <span class="o">==</span> <span class
@@ -218,4 +220,4 @@
   </script> 
   </script> 
 
 
 </body>
 </body>
-</html>
+</html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.common.aws_connection.aws_secrets_manager_connector &mdash; SuperGradients 1.0 documentation</title>
  7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
  10. <!--[if lt IE 9]>
  11. <script src="../../../../_static/js/html5shiv.min.js"></script>
  12. <![endif]-->
  13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
  14. <script src="../../../../_static/jquery.js"></script>
  15. <script src="../../../../_static/underscore.js"></script>
  16. <script src="../../../../_static/doctools.js"></script>
  17. <script src="../../../../_static/js/theme.js"></script>
  18. <link rel="index" title="Index" href="../../../../genindex.html" />
  19. <link rel="search" title="Search" href="../../../../search.html" />
  20. </head>
  21. <body class="wy-body-for-nav">
  22. <div class="wy-grid-for-nav">
  23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  24. <div class="wy-side-scroll">
  25. <div class="wy-side-nav-search" >
  26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
  27. </a>
  28. <div role="search">
  29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
  30. <input type="text" name="q" placeholder="Search docs" />
  31. <input type="hidden" name="check_keywords" value="yes" />
  32. <input type="hidden" name="area" value="default" />
  33. </form>
  34. </div>
  35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
  37. <ul>
  38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
  39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
  40. </ul>
  41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
  42. <ul>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
  45. </ul>
  46. <p class="caption"><span class="caption-text">User Guide</span></p>
  47. <ul>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
  57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
  58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
  59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
  60. </ul>
  61. </div>
  62. </div>
  63. </nav>
  64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  66. <a href="../../../../index.html">SuperGradients</a>
  67. </nav>
  68. <div class="wy-nav-content">
  69. <div class="rst-content">
  70. <div role="navigation" aria-label="Page navigation">
  71. <ul class="wy-breadcrumbs">
  72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
  74. <li>super_gradients.common.aws_connection.aws_secrets_manager_connector</li>
  75. <li class="wy-breadcrumbs-aside">
  76. </li>
  77. </ul>
  78. <hr/>
  79. </div>
  80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  81. <div itemprop="articleBody">
  82. <h1>Source code for super_gradients.common.aws_connection.aws_secrets_manager_connector</h1><div class="highlight"><pre>
  83. <span></span><span class="kn">import</span> <span class="nn">json</span>
  84. <span class="kn">import</span> <span class="nn">logging</span>
  85. <span class="kn">from</span> <span class="nn">super_gradients.common</span> <span class="kn">import</span> <span class="n">AWSConnector</span>
  86. <span class="kn">from</span> <span class="nn">super_gradients.common</span> <span class="kn">import</span> <span class="n">explicit_params_validation</span>
  87. <div class="viewcode-block" id="AWSSecretsManagerConnector"><a class="viewcode-back" href="../../../../super_gradients.common.aws_connection.html#super_gradients.common.aws_connection.aws_secrets_manager_connector.AWSSecretsManagerConnector">[docs]</a><span class="k">class</span> <span class="nc">AWSSecretsManagerConnector</span><span class="p">:</span>
  88. <span class="sd">&quot;&quot;&quot;</span>
  89. <span class="sd"> AWSSecretsManagerConnector - This class handles the AWS Secrets Manager connection</span>
  90. <span class="sd"> &quot;&quot;&quot;</span>
  91. <span class="vm">__slots__</span> <span class="o">=</span> <span class="p">[]</span> <span class="c1"># Making the class immutable for runtime safety</span>
  92. <span class="n">current_environment_client</span> <span class="o">=</span> <span class="kc">None</span>
  93. <span class="n">DECI_ENVIRONMENTS</span> <span class="o">=</span> <span class="p">[</span><span class="s1">&#39;research&#39;</span><span class="p">,</span> <span class="s1">&#39;development&#39;</span><span class="p">,</span> <span class="s1">&#39;staging&#39;</span><span class="p">,</span> <span class="s1">&#39;production&#39;</span><span class="p">]</span>
  94. <span class="nd">@staticmethod</span>
  95. <span class="nd">@explicit_params_validation</span><span class="p">(</span><span class="n">validation_type</span><span class="o">=</span><span class="s1">&#39;NoneOrEmpty&#39;</span><span class="p">)</span>
  96. <span class="k">def</span> <span class="nf">get_secret_value_for_secret_key</span><span class="p">(</span><span class="n">aws_env</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">secret_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">secret_key</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
  97. <span class="sd">&quot;&quot;&quot;</span>
  98. <span class="sd"> get_secret_value_for_secret_key - Gets a Secret Value from AWS Secrets Manager for the Provided Key</span>
  99. <span class="sd"> :param aws_env: The environment to get the secret for</span>
  100. <span class="sd"> :param secret_name: The Secret Name stored in Secrets Manager</span>
  101. <span class="sd"> :param secret_key: The Secret Key To retrieve it&#39;s value from AWS</span>
  102. <span class="sd"> :return: str: The Secret Value</span>
  103. <span class="sd"> &quot;&quot;&quot;</span>
  104. <span class="n">current_class_name</span> <span class="o">=</span> <span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span>
  105. <span class="n">logger</span> <span class="o">=</span> <span class="n">logging</span><span class="o">.</span><span class="n">getLogger</span><span class="p">(</span><span class="n">current_class_name</span><span class="p">)</span>
  106. <span class="n">secret_key</span> <span class="o">=</span> <span class="n">secret_key</span><span class="o">.</span><span class="n">upper</span><span class="p">()</span>
  107. <span class="n">aws_secrets_dict</span> <span class="o">=</span> <span class="n">AWSSecretsManagerConnector</span><span class="o">.</span><span class="n">__get_secrets_manager_dict_for_secret_name</span><span class="p">(</span>
  108. <span class="n">aws_env</span><span class="o">=</span><span class="n">aws_env</span><span class="p">,</span> <span class="n">secret_name</span><span class="o">=</span><span class="n">secret_name</span><span class="p">)</span>
  109. <span class="n">secret_key</span> <span class="o">=</span> <span class="s1">&#39;.&#39;</span><span class="o">.</span><span class="n">join</span><span class="p">([</span><span class="n">aws_env</span><span class="o">.</span><span class="n">upper</span><span class="p">(),</span> <span class="n">secret_key</span><span class="p">])</span>
  110. <span class="k">if</span> <span class="n">secret_key</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">aws_secrets_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
  111. <span class="n">error</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">&#39;[</span><span class="si">{</span><span class="n">current_class_name</span><span class="si">}</span><span class="s1">] - Secret Key (</span><span class="si">{</span><span class="n">secret_key</span><span class="si">}</span><span class="s1">) not Found in AWS Secret: &#39;</span> <span class="o">+</span> <span class="n">secret_name</span>
  112. <span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="n">error</span><span class="p">)</span>
  113. <span class="k">raise</span> <span class="ne">EnvironmentError</span><span class="p">(</span><span class="n">error</span><span class="p">)</span>
  114. <span class="k">else</span><span class="p">:</span>
  115. <span class="k">return</span> <span class="n">aws_secrets_dict</span><span class="p">[</span><span class="n">secret_key</span><span class="p">]</span>
  116. <span class="nd">@staticmethod</span>
  117. <span class="nd">@explicit_params_validation</span><span class="p">(</span><span class="n">validation_type</span><span class="o">=</span><span class="s1">&#39;NoneOrEmpty&#39;</span><span class="p">)</span>
  118. <span class="k">def</span> <span class="nf">get_secret_values_dict_for_secret_key_properties</span><span class="p">(</span><span class="n">env</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">secret_key</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">secret_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
  119. <span class="n">db_properties_set</span><span class="p">:</span> <span class="nb">set</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">dict</span><span class="p">:</span>
  120. <span class="sd">&quot;&quot;&quot;</span>
  121. <span class="sd"> get_config_dict - Returns the config dict of the properties from the properties dict</span>
  122. <span class="sd"> :param env: The environment to open the dict for</span>
  123. <span class="sd"> :param secret_key: The Secret Key</span>
  124. <span class="sd"> :param secret_name: The Secret to Retrieve to from AWS secrets manager (usually project name)</span>
  125. <span class="sd"> :param db_properties_set: The set of the properties to get secrets values for</span>
  126. <span class="sd"> :return: dict The secrets dict for the requested property</span>
  127. <span class="sd"> &quot;&quot;&quot;</span>
  128. <span class="n">current_class_name</span> <span class="o">=</span> <span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span>
  129. <span class="n">logger</span> <span class="o">=</span> <span class="n">logging</span><span class="o">.</span><span class="n">getLogger</span><span class="p">(</span><span class="n">current_class_name</span><span class="p">)</span>
  130. <span class="n">aws_secrets_dict</span> <span class="o">=</span> <span class="n">AWSSecretsManagerConnector</span><span class="o">.</span><span class="n">__get_secrets_manager_dict_for_secret_name</span><span class="p">(</span>
  131. <span class="n">aws_env</span><span class="o">=</span><span class="n">env</span><span class="p">,</span> <span class="n">secret_name</span><span class="o">=</span><span class="n">secret_name</span><span class="p">)</span>
  132. <span class="n">aws_env_safe_secrets</span> <span class="o">=</span> <span class="p">{}</span>
  133. <span class="c1"># FILL THE DICT VALUES FROM THE AWS SECRETS RESPONSE</span>
  134. <span class="k">if</span> <span class="n">db_properties_set</span><span class="p">:</span>
  135. <span class="k">for</span> <span class="n">secret_key_property</span> <span class="ow">in</span> <span class="n">db_properties_set</span><span class="p">:</span>
  136. <span class="n">secret_key_to_retrieve</span> <span class="o">=</span> <span class="s1">&#39;.&#39;</span><span class="o">.</span><span class="n">join</span><span class="p">([</span><span class="n">env</span><span class="o">.</span><span class="n">upper</span><span class="p">(),</span> <span class="n">secret_key</span><span class="p">,</span> <span class="n">secret_key_property</span><span class="p">])</span>
  137. <span class="k">if</span> <span class="n">secret_key_to_retrieve</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">aws_secrets_dict</span><span class="p">:</span>
  138. <span class="n">error</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">&#39;[</span><span class="si">{</span><span class="n">current_class_name</span><span class="si">}</span><span class="s1">] - Error retrieving data from AWS Secrets Manager for Secret Key &quot;</span><span class="si">{</span><span class="n">secret_name</span><span class="si">}</span><span class="s1">&quot;: &#39;</span> \
  139. <span class="sa">f</span><span class="s1">&#39;The secret property &quot;</span><span class="si">{</span><span class="n">secret_key_property</span><span class="si">}</span><span class="s1">&quot; Does Not Exist&#39;</span>
  140. <span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="n">error</span><span class="p">)</span>
  141. <span class="k">raise</span> <span class="ne">EnvironmentError</span><span class="p">(</span><span class="n">error</span><span class="p">)</span>
  142. <span class="k">else</span><span class="p">:</span>
  143. <span class="n">env_stripped_key_name</span> <span class="o">=</span> <span class="n">secret_key_to_retrieve</span><span class="o">.</span><span class="n">lstrip</span><span class="p">(</span><span class="n">env</span><span class="o">.</span><span class="n">upper</span><span class="p">())</span><span class="o">.</span><span class="n">lstrip</span><span class="p">(</span><span class="s1">&#39;.&#39;</span><span class="p">)</span>
  144. <span class="n">aws_env_safe_secrets</span><span class="p">[</span><span class="n">env_stripped_key_name</span><span class="p">]</span> <span class="o">=</span> <span class="n">aws_secrets_dict</span><span class="p">[</span><span class="n">secret_key_to_retrieve</span><span class="p">]</span>
  145. <span class="k">else</span><span class="p">:</span>
  146. <span class="c1"># &quot;db_properties_set&quot; is not specified - validating and returning all the secret keys and values for</span>
  147. <span class="c1"># the secret name.</span>
  148. <span class="k">for</span> <span class="n">secret_key_name</span><span class="p">,</span> <span class="n">secret_value</span> <span class="ow">in</span> <span class="n">aws_secrets_dict</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
  149. <span class="n">secret_key_to_retrieve</span> <span class="o">=</span> <span class="s1">&#39;.&#39;</span><span class="o">.</span><span class="n">join</span><span class="p">([</span><span class="n">env</span><span class="o">.</span><span class="n">upper</span><span class="p">(),</span> <span class="n">secret_key</span><span class="p">])</span>
  150. <span class="k">assert</span> <span class="n">secret_key_name</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span>
  151. <span class="n">env</span><span class="o">.</span><span class="n">upper</span><span class="p">()),</span> <span class="sa">f</span><span class="s1">&#39;The secret key property &quot;</span><span class="si">{</span><span class="n">secret_key_name</span><span class="si">}</span><span class="s1">&quot;, found in secret named </span><span class="si">{</span><span class="n">secret_name</span><span class="si">}</span><span class="s1">, is not following the convention of &#39;</span> \
  152. <span class="sa">f</span><span class="s1">&#39;environment prefix. please add the environment prefix &quot;</span><span class="si">{</span><span class="n">env</span><span class="o">.</span><span class="n">upper</span><span class="p">()</span><span class="si">}</span><span class="s1">&quot; to property &quot;</span><span class="si">{</span><span class="n">secret_key_name</span><span class="si">}</span><span class="s1">&quot;&#39;</span>
  153. <span class="k">if</span> <span class="n">secret_key_name</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="n">secret_key_to_retrieve</span><span class="p">):</span>
  154. <span class="n">env_stripped_key_name</span> <span class="o">=</span> <span class="n">secret_key_name</span><span class="o">.</span><span class="n">lstrip</span><span class="p">(</span><span class="n">env</span><span class="o">.</span><span class="n">upper</span><span class="p">())</span><span class="o">.</span><span class="n">lstrip</span><span class="p">(</span><span class="s1">&#39;.&#39;</span><span class="p">)</span>
  155. <span class="n">aws_env_safe_secrets</span><span class="p">[</span><span class="n">env_stripped_key_name</span><span class="p">]</span> <span class="o">=</span> <span class="n">secret_value</span>
  156. <span class="k">return</span> <span class="n">aws_env_safe_secrets</span>
  157. <span class="nd">@staticmethod</span>
  158. <span class="k">def</span> <span class="nf">__get_secrets_manager_dict_for_secret_name</span><span class="p">(</span><span class="n">aws_env</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">secret_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">dict</span><span class="p">:</span>
  159. <span class="sd">&quot;&quot;&quot;</span>
  160. <span class="sd"> __get_secrets_manager_dict_for_secret_name</span>
  161. <span class="sd"> :param aws_env: The environment to open the dict for</span>
  162. <span class="sd"> :param secret_name: The Secret to Retrieve to from AWS secrets manager (usually project name)</span>
  163. <span class="sd"> :return: python Dictionary with the key/value pairs stored in AWS Secrets Manager</span>
  164. <span class="sd"> &quot;&quot;&quot;</span>
  165. <span class="n">current_class_name</span> <span class="o">=</span> <span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span>
  166. <span class="n">logger</span> <span class="o">=</span> <span class="n">logging</span><span class="o">.</span><span class="n">getLogger</span><span class="p">(</span><span class="n">current_class_name</span><span class="p">)</span>
  167. <span class="n">secrets_path</span> <span class="o">=</span> <span class="n">AWSSecretsManagerConnector</span><span class="o">.</span><span class="n">__get_secrets_path_from_secret_name</span><span class="p">(</span><span class="n">aws_env</span><span class="p">,</span> <span class="n">secret_name</span><span class="p">)</span>
  168. <span class="k">try</span><span class="p">:</span>
  169. <span class="k">if</span> <span class="ow">not</span> <span class="n">AWSSecretsManagerConnector</span><span class="o">.</span><span class="n">current_environment_client</span><span class="p">:</span>
  170. <span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="s1">&#39;Initializing a new secrets manager client...&#39;</span><span class="p">)</span>
  171. <span class="n">AWSSecretsManagerConnector</span><span class="o">.</span><span class="n">current_environment_client</span> <span class="o">=</span> <span class="n">AWSConnector</span><span class="o">.</span><span class="n">get_aws_client_for_service_name</span><span class="p">(</span>
  172. <span class="n">profile_name</span><span class="o">=</span><span class="n">aws_env</span><span class="p">,</span>
  173. <span class="n">service_name</span><span class="o">=</span><span class="s1">&#39;secretsmanager&#39;</span><span class="p">)</span>
  174. <span class="n">logger</span><span class="o">.</span><span class="n">debug</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Fetching the secret &quot;</span><span class="si">{</span><span class="n">secret_name</span><span class="si">}</span><span class="s1">&quot; in env &quot;</span><span class="si">{</span><span class="n">aws_env</span><span class="si">}</span><span class="s1">&quot;&#39;</span><span class="p">)</span>
  175. <span class="n">aws_secrets</span> <span class="o">=</span> <span class="n">AWSSecretsManagerConnector</span><span class="o">.</span><span class="n">current_environment_client</span><span class="o">.</span><span class="n">get_secret_value</span><span class="p">(</span><span class="n">SecretId</span><span class="o">=</span><span class="n">secrets_path</span><span class="p">)</span>
  176. <span class="n">aws_secrets_dict</span> <span class="o">=</span> <span class="n">json</span><span class="o">.</span><span class="n">loads</span><span class="p">(</span><span class="n">aws_secrets</span><span class="p">[</span><span class="s1">&#39;SecretString&#39;</span><span class="p">])</span>
  177. <span class="k">return</span> <span class="n">aws_secrets_dict</span>
  178. <span class="k">except</span> <span class="ne">Exception</span> <span class="k">as</span> <span class="n">ex</span><span class="p">:</span>
  179. <span class="n">error</span> <span class="o">=</span> <span class="s1">&#39;[&#39;</span> <span class="o">+</span> <span class="n">current_class_name</span> <span class="o">+</span> <span class="s1">&#39;] - Caught Exception while trying to connect to aws to get credentials from secrets manager: &#39;</span> <span class="o">+</span> <span class="s1">&#39;&quot;&#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span>
  180. <span class="n">ex</span><span class="p">)</span> <span class="o">+</span> <span class="s1">&#39;&quot;&#39;</span> <span class="o">+</span> <span class="s1">&#39; for &#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">secrets_path</span><span class="p">)</span>
  181. <span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="n">error</span><span class="p">)</span>
  182. <span class="k">raise</span> <span class="ne">EnvironmentError</span><span class="p">(</span><span class="n">error</span><span class="p">)</span>
  183. <span class="nd">@staticmethod</span>
  184. <span class="k">def</span> <span class="nf">__get_secrets_path_from_secret_name</span><span class="p">(</span><span class="n">aws_env</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">secret_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
  185. <span class="sd">&quot;&quot;&quot;</span>
  186. <span class="sd"> __get_secrets_path_from_secret_name - Extracts the full secret path based on the Environment</span>
  187. <span class="sd"> :param aws_env: Env</span>
  188. <span class="sd"> :param secret_name: Secret Name</span>
  189. <span class="sd"> :return: str: The full secret path</span>
  190. <span class="sd"> &quot;&quot;&quot;</span>
  191. <span class="n">current_class_name</span> <span class="o">=</span> <span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span>
  192. <span class="n">logger</span> <span class="o">=</span> <span class="n">logging</span><span class="o">.</span><span class="n">getLogger</span><span class="p">(</span><span class="n">current_class_name</span><span class="p">)</span>
  193. <span class="c1"># Checking for lowercase exact match, in order to prevent any implicit usage of the environments.</span>
  194. <span class="k">if</span> <span class="n">aws_env</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">AWSSecretsManagerConnector</span><span class="o">.</span><span class="n">DECI_ENVIRONMENTS</span><span class="p">:</span>
  195. <span class="n">logger</span><span class="o">.</span><span class="n">critical</span><span class="p">(</span><span class="s1">&#39;[&#39;</span> <span class="o">+</span> <span class="n">current_class_name</span> <span class="o">+</span> <span class="s1">&#39; ] - wrong environment param... Exiting&#39;</span><span class="p">)</span>
  196. <span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="s1">&#39;[&#39;</span> <span class="o">+</span> <span class="n">current_class_name</span> <span class="o">+</span> <span class="s1">&#39;] - wrong environment param&#39;</span><span class="p">)</span>
  197. <span class="n">secrets_path</span> <span class="o">=</span> <span class="s1">&#39;/&#39;</span><span class="o">.</span><span class="n">join</span><span class="p">([</span><span class="n">aws_env</span><span class="p">,</span> <span class="n">secret_name</span><span class="p">])</span>
  198. <span class="k">return</span> <span class="n">secrets_path</span></div>
  199. </pre></div>
  200. </div>
  201. </div>
  202. <footer>
  203. <hr/>
  204. <div role="contentinfo">
  205. <p>&#169; Copyright 2021, SuperGradients team.</p>
  206. </div>
  207. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  208. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  209. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  210. </footer>
  211. </div>
  212. </div>
  213. </section>
  214. </div>
  215. <script>
  216. jQuery(function () {
  217. SphinxRtdTheme.Navigation.enable(true);
  218. });
  219. </script>
  220. </body>
  221. </html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.common.crash_handler.crash_handler &mdash; SuperGradients 3.0.3 documentation</title>
  7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
  10. <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
  11. <!--[if lt IE 9]>
  12. <script src="../../../../_static/js/html5shiv.min.js"></script>
  13. <![endif]-->
  14. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
  15. <script src="../../../../_static/jquery.js"></script>
  16. <script src="../../../../_static/underscore.js"></script>
  17. <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
  18. <script src="../../../../_static/doctools.js"></script>
  19. <script src="../../../../_static/sphinx_highlight.js"></script>
  20. <script src="../../../../_static/js/theme.js"></script>
  21. <link rel="index" title="Index" href="../../../../genindex.html" />
  22. <link rel="search" title="Search" href="../../../../search.html" />
  23. </head>
  24. <body class="wy-body-for-nav">
  25. <div class="wy-grid-for-nav">
  26. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  27. <div class="wy-side-scroll">
  28. <div class="wy-side-nav-search" >
  29. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
  30. </a>
  31. <div role="search">
  32. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
  33. <input type="text" name="q" placeholder="Search docs" />
  34. <input type="hidden" name="check_keywords" value="yes" />
  35. <input type="hidden" name="area" value="default" />
  36. </form>
  37. </div>
  38. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  39. <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
  40. <ul>
  41. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
  42. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
  45. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
  46. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
  47. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
  57. </ul>
  58. <p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
  59. <ul>
  60. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
  61. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
  62. </ul>
  63. </div>
  64. </div>
  65. </nav>
  66. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  67. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  68. <a href="../../../../index.html">SuperGradients</a>
  69. </nav>
  70. <div class="wy-nav-content">
  71. <div class="rst-content">
  72. <div role="navigation" aria-label="Page navigation">
  73. <ul class="wy-breadcrumbs">
  74. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  75. <li><a href="../../../index.html">Module code</a> &raquo;</li>
  76. <li>super_gradients.common.crash_handler.crash_handler</li>
  77. <li class="wy-breadcrumbs-aside">
  78. </li>
  79. </ul>
  80. <hr/>
  81. </div>
  82. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  83. <div itemprop="articleBody">
  84. <h1>Source code for super_gradients.common.crash_handler.crash_handler</h1><div class="highlight"><pre>
  85. <span></span><span class="kn">import</span> <span class="nn">sys</span>
  86. <span class="kn">from</span> <span class="nn">super_gradients.common.crash_handler.crash_tips_setup</span> <span class="kn">import</span> <span class="n">setup_crash_tips</span>
  87. <span class="kn">from</span> <span class="nn">super_gradients.common.crash_handler.exception_monitoring_setup</span> <span class="kn">import</span> <span class="n">setup_pro_user_monitoring</span>
  88. <span class="kn">from</span> <span class="nn">super_gradients.common.crash_handler.exception</span> <span class="kn">import</span> <span class="n">register_exceptions</span>
  89. <div class="viewcode-block" id="setup_crash_handler"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.setup_crash_handler">[docs]</a><span class="k">def</span> <span class="nf">setup_crash_handler</span><span class="p">():</span>
  90. <span class="sd">&quot;&quot;&quot;Setup the environment to handle crashes, with crash tips and more.&quot;&quot;&quot;</span>
  91. <span class="n">is_setup_crash_tips</span> <span class="o">=</span> <span class="n">setup_crash_tips</span><span class="p">()</span>
  92. <span class="n">is_setup_pro_user_monitoring</span> <span class="o">=</span> <span class="n">setup_pro_user_monitoring</span><span class="p">()</span>
  93. <span class="k">if</span> <span class="n">is_setup_crash_tips</span> <span class="ow">or</span> <span class="n">is_setup_pro_user_monitoring</span><span class="p">:</span> <span class="c1"># We don&#39;t want to wrap sys.excepthook when not required</span>
  94. <span class="n">sys</span><span class="o">.</span><span class="n">excepthook</span> <span class="o">=</span> <span class="n">register_exceptions</span><span class="p">(</span><span class="n">sys</span><span class="o">.</span><span class="n">excepthook</span><span class="p">)</span></div>
  95. </pre></div>
  96. </div>
  97. </div>
  98. <footer>
  99. <hr/>
  100. <div role="contentinfo">
  101. <p>&#169; Copyright 2021, SuperGradients team.</p>
  102. </div>
  103. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  104. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  105. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  106. </footer>
  107. </div>
  108. </div>
  109. </section>
  110. </div>
  111. <script>
  112. jQuery(function () {
  113. SphinxRtdTheme.Navigation.enable(true);
  114. });
  115. </script>
  116. </body>
  117. </html>
Discard
Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
@@ -3,10 +3,11 @@
 <head>
 <head>
   <meta charset="utf-8" />
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-  <title>super_gradients.common.data_connection.s3_connector &mdash; SuperGradients 1.0 documentation</title>
+  <title>super_gradients.common.data_connection.s3_connector &mdash; SuperGradients 3.0.3 documentation</title>
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
+      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
   <!--[if lt IE 9]>
   <!--[if lt IE 9]>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
   <![endif]-->
   <![endif]-->
@@ -14,7 +15,9 @@
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/underscore.js"></script>
         <script src="../../../../_static/underscore.js"></script>
+        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
         <script src="../../../../_static/doctools.js"></script>
         <script src="../../../../_static/doctools.js"></script>
+        <script src="../../../../_static/sphinx_highlight.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="search" title="Search" href="../../../../search.html" /> 
     <link rel="search" title="Search" href="../../../../search.html" /> 
@@ -35,30 +38,29 @@
   </form>
   </form>
 </div>
 </div>
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
-              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
+              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
 <ul>
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
 </ul>
 </ul>
-<p class="caption"><span class="caption-text">Technical Documentation</span></p>
+<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
 <ul>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
-</ul>
-<p class="caption"><span class="caption-text">User Guide</span></p>
-<ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
 </ul>
 </ul>
 
 
         </div>
         </div>
@@ -98,11 +100,11 @@
 <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">ILogger</span>
 <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">ILogger</span>
 
 
 
 
-<div class="viewcode-block" id="KeyNotExistInBucketError"><a class="viewcode-back" href="../../../../super_gradients.common.data_connection.html#super_gradients.common.KeyNotExistInBucketError">[docs]</a><span class="k">class</span> <span class="nc">KeyNotExistInBucketError</span><span class="p">(</span><span class="ne">Exception</span><span class="p">):</span>
-    <span class="k">pass</span></div>
+<span class="k">class</span> <span class="nc">KeyNotExistInBucketError</span><span class="p">(</span><span class="ne">Exception</span><span class="p">):</span>
+    <span class="k">pass</span>
 
 
 
 
-<div class="viewcode-block" id="S3Connector"><a class="viewcode-back" href="../../../../super_gradients.common.data_connection.html#super_gradients.common.S3Connector">[docs]</a><span class="k">class</span> <span class="nc">S3Connector</span><span class="p">(</span><span class="n">ILogger</span><span class="p">):</span>
+<div class="viewcode-block" id="S3Connector"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.S3Connector">[docs]</a><span class="k">class</span> <span class="nc">S3Connector</span><span class="p">(</span><span class="n">ILogger</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    S3Connector - S3 Connection Manager</span>
 <span class="sd">    S3Connector - S3 Connection Manager</span>
 <span class="sd">    &quot;&quot;&quot;</span>
 <span class="sd">    &quot;&quot;&quot;</span>
@@ -458,7 +460,7 @@
                                                          <span class="n">ExpiresIn</span><span class="o">=</span><span class="n">expiration</span><span class="p">)</span>
                                                          <span class="n">ExpiresIn</span><span class="o">=</span><span class="n">expiration</span><span class="p">)</span>
         <span class="k">return</span> <span class="n">response</span>
         <span class="k">return</span> <span class="n">response</span>
 
 
-<div class="viewcode-block" id="S3Connector.convert_content_length_to_mb"><a class="viewcode-back" href="../../../../super_gradients.common.data_connection.html#super_gradients.common.S3Connector.convert_content_length_to_mb">[docs]</a>    <span class="nd">@staticmethod</span>
+<div class="viewcode-block" id="S3Connector.convert_content_length_to_mb"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.S3Connector.convert_content_length_to_mb">[docs]</a>    <span class="nd">@staticmethod</span>
     <span class="k">def</span> <span class="nf">convert_content_length_to_mb</span><span class="p">(</span><span class="n">content_length</span><span class="p">):</span>
     <span class="k">def</span> <span class="nf">convert_content_length_to_mb</span><span class="p">(</span><span class="n">content_length</span><span class="p">):</span>
         <span class="k">return</span> <span class="nb">round</span><span class="p">(</span><span class="nb">float</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">content_length</span> <span class="o">/</span> <span class="p">(</span><span class="mf">1e+6</span><span class="p">)</span><span class="si">:</span><span class="s1">2f</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">),</span> <span 
         <span class="k">return</span> <span class="nb">round</span><span class="p">(</span><span class="nb">float</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">content_length</span> <span class="o">/</span> <span class="p">(</span><span class="mf">1e+6</span><span class="p">)</span><span class="si">:</span><span class="s1">2f</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">),</span> <span 
 
 
@@ -524,4 +526,4 @@
   </script> 
   </script> 
 
 
 </body>
 </body>
-</html>
+</html>
Discard
@@ -3,10 +3,11 @@
 <head>
 <head>
   <meta charset="utf-8" />
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-  <title>super_gradients.common.data_interface.adnn_model_repository_data_interface &mdash; SuperGradients 1.0 documentation</title>
+  <title>super_gradients.common.data_interface.adnn_model_repository_data_interface &mdash; SuperGradients 3.0.3 documentation</title>
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
+      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
   <!--[if lt IE 9]>
   <!--[if lt IE 9]>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
   <![endif]-->
   <![endif]-->
@@ -14,7 +15,9 @@
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/underscore.js"></script>
         <script src="../../../../_static/underscore.js"></script>
+        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
         <script src="../../../../_static/doctools.js"></script>
         <script src="../../../../_static/doctools.js"></script>
+        <script src="../../../../_static/sphinx_highlight.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="search" title="Search" href="../../../../search.html" /> 
     <link rel="search" title="Search" href="../../../../search.html" /> 
@@ -35,30 +38,29 @@
   </form>
   </form>
 </div>
 </div>
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
-              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
+              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
 <ul>
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
 </ul>
 </ul>
-<p class="caption"><span class="caption-text">Technical Documentation</span></p>
+<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
 <ul>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
-</ul>
-<p class="caption"><span class="caption-text">User Guide</span></p>
-<ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
 </ul>
 </ul>
 
 
         </div>
         </div>
@@ -92,11 +94,11 @@
 <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">ILogger</span>
 <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">ILogger</span>
 
 
 
 
-<div class="viewcode-block" id="ModelCheckpointNotFoundException"><a class="viewcode-back" href="../../../../super_gradients.common.data_interface.html#super_gradients.common.ModelCheckpointNotFoundException">[docs]</a><span class="k">class</span> <span class="nc">ModelCheckpointNotFoundException</span><span class="p">(</span><span class="ne">RuntimeError</span><span class="p">):</span>
-    <span class="k">pass</span></div>
+<span class="k">class</span> <span class="nc">ModelCheckpointNotFoundException</span><span class="p">(</span><span class="ne">RuntimeError</span><span class="p">):</span>
+    <span class="k">pass</span>
 
 
 
 
-<div class="viewcode-block" id="ADNNModelRepositoryDataInterfaces"><a class="viewcode-back" href="../../../../super_gradients.common.data_interface.html#super_gradients.common.ADNNModelRepositoryDataInterfaces">[docs]</a><span class="k">class</span> <span class="nc">ADNNModelRepositoryDataInterfaces</span><span class="p">(</span><span class="n">ILogger</span><span class="p">):</span>
+<div class="viewcode-block" id="ADNNModelRepositoryDataInterfaces"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.ADNNModelRepositoryDataInterfaces">[docs]</a><span class="k">class</span> <span class="nc">ADNNModelRepositoryDataInterfaces</span><span class="p">(</span><span class="n">ILogger</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    ResearchModelRepositoryDataInterface</span>
 <span class="sd">    ResearchModelRepositoryDataInterface</span>
 <span class="sd">    &quot;&quot;&quot;</span>
 <span class="sd">    &quot;&quot;&quot;</span>
@@ -293,4 +295,4 @@
   </script> 
   </script> 
 
 
 </body>
 </body>
-</html>
+</html>
Discard
@@ -3,10 +3,11 @@
 <head>
 <head>
   <meta charset="utf-8" />
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-  <title>super_gradients.common.data_interface.dataset_data_interface &mdash; SuperGradients 1.0 documentation</title>
+  <title>super_gradients.common.data_interface.dataset_data_interface &mdash; SuperGradients 3.0.3 documentation</title>
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
+      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
   <!--[if lt IE 9]>
   <!--[if lt IE 9]>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
   <![endif]-->
   <![endif]-->
@@ -14,7 +15,9 @@
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/underscore.js"></script>
         <script src="../../../../_static/underscore.js"></script>
+        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
         <script src="../../../../_static/doctools.js"></script>
         <script src="../../../../_static/doctools.js"></script>
+        <script src="../../../../_static/sphinx_highlight.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="search" title="Search" href="../../../../search.html" /> 
     <link rel="search" title="Search" href="../../../../search.html" /> 
@@ -35,30 +38,29 @@
   </form>
   </form>
 </div>
 </div>
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
-              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
+              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
 <ul>
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
 </ul>
 </ul>
-<p class="caption"><span class="caption-text">Technical Documentation</span></p>
+<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
 <ul>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
-</ul>
-<p class="caption"><span class="caption-text">User Guide</span></p>
-<ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
 </ul>
 </ul>
 
 
         </div>
         </div>
@@ -92,7 +94,7 @@
 <span class="kn">import</span> <span class="nn">zipfile</span>
 <span class="kn">import</span> <span class="nn">zipfile</span>
 
 
 
 
-<div class="viewcode-block" id="DatasetDataInterface"><a class="viewcode-back" href="../../../../super_gradients.common.data_interface.html#super_gradients.common.DatasetDataInterface">[docs]</a><span class="k">class</span> <span class="nc">DatasetDataInterface</span><span class="p">:</span>
+<div class="viewcode-block" id="DatasetDataInterface"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.DatasetDataInterface">[docs]</a><span class="k">class</span> <span class="nc">DatasetDataInterface</span><span class="p">:</span>
 
 
     <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">env</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">data_connection_source</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39;s3&#39;</span><span class="p">):</span>
     <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">env</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">data_connection_source</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39;s3&#39;</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
@@ -167,4 +169,4 @@
   </script> 
   </script> 
 
 
 </body>
 </body>
-</html>
+</html>
Discard
@@ -3,10 +3,11 @@
 <head>
 <head>
   <meta charset="utf-8" />
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-  <title>super_gradients.common.data_types.enum.deep_learning_task &mdash; SuperGradients 1.0 documentation</title>
+  <title>super_gradients.common.data_types.enum.deep_learning_task &mdash; SuperGradients 3.0.3 documentation</title>
       <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
+      <link rel="stylesheet" href="../../../../../_static/custom.css" type="text/css" />
   <!--[if lt IE 9]>
   <!--[if lt IE 9]>
     <script src="../../../../../_static/js/html5shiv.min.js"></script>
     <script src="../../../../../_static/js/html5shiv.min.js"></script>
   <![endif]-->
   <![endif]-->
@@ -14,7 +15,9 @@
         <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
         <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
         <script src="../../../../../_static/jquery.js"></script>
         <script src="../../../../../_static/jquery.js"></script>
         <script src="../../../../../_static/underscore.js"></script>
         <script src="../../../../../_static/underscore.js"></script>
+        <script src="../../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
         <script src="../../../../../_static/doctools.js"></script>
         <script src="../../../../../_static/doctools.js"></script>
+        <script src="../../../../../_static/sphinx_highlight.js"></script>
     <script src="../../../../../_static/js/theme.js"></script>
     <script src="../../../../../_static/js/theme.js"></script>
     <link rel="index" title="Index" href="../../../../../genindex.html" />
     <link rel="index" title="Index" href="../../../../../genindex.html" />
     <link rel="search" title="Search" href="../../../../../search.html" /> 
     <link rel="search" title="Search" href="../../../../../search.html" /> 
@@ -35,30 +38,29 @@
   </form>
   </form>
 </div>
 </div>
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
-              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
+              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
 <ul>
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#supergradients">SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#quick-installation">Quick Installation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#what-s-new">What’s New</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#coming-soon">Coming soon</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#table-of-content">Table of Content</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#getting-started">Getting Started</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#advanced-features">Advanced Features</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#installation-methods">Installation Methods</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#documentation">Documentation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#contributing">Contributing</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#citation">Citation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#community">Community</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#license">License</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#deci-platform">Deci Platform</a></li>
 </ul>
 </ul>
-<p class="caption"><span class="caption-text">Technical Documentation</span></p>
+<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
 <ul>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
-</ul>
-<p class="caption"><span class="caption-text">User Guide</span></p>
-<ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html">What is SuperGradients?</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#installation">Installation</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#training-parameters">Training Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#network-architectures">Network Architectures</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
 </ul>
 </ul>
 
 
         </div>
         </div>
@@ -89,7 +91,7 @@
 <span></span><span class="kn">from</span> <span class="nn">enum</span> <span class="kn">import</span> <span class="n">Enum</span>
 <span></span><span class="kn">from</span> <span class="nn">enum</span> <span class="kn">import</span> <span class="n">Enum</span>
 
 
 
 
-<div class="viewcode-block" id="DeepLearningTask"><a class="viewcode-back" href="../../../../../super_gradients.common.data_types.enum.html#super_gradients.common.DeepLearningTask">[docs]</a><span class="k">class</span> <span class="nc">DeepLearningTask</span><span class="p">(</span><span class="nb">str</span><span class="p">,</span> <span class="n">Enum</span><span class="p">):</span>
+<div class="viewcode-block" id="DeepLearningTask"><a class="viewcode-back" href="../../../../../super_gradients.common.html#super_gradients.common.DeepLearningTask">[docs]</a><span class="k">class</span> <span class="nc">DeepLearningTask</span><span class="p">(</span><span class="nb">str</span><span class="p">,</span> <span class="n">Enum</span><span class="p">):</span>
     <span class="n">CLASSIFICATION</span> <span class="o">=</span> <span class="s1">&#39;classification&#39;</span>
     <span class="n">CLASSIFICATION</span> <span class="o">=</span> <span class="s1">&#39;classification&#39;</span>
     <span class="n">SEMANTIC_SEGMENTATION</span> <span class="o">=</span> <span class="s1">&#39;semantic_segmentation&#39;</span>
     <span class="n">SEMANTIC_SEGMENTATION</span> <span class="o">=</span> <span class="s1">&#39;semantic_segmentation&#39;</span>
     <span class="n">OBJECT_DETECTION</span> <span class="o">=</span> <span class="s1">&#39;object_detection&#39;</span>
     <span class="n">OBJECT_DETECTION</span> <span class="o">=</span> <span class="s1">&#39;object_detection&#39;</span>
@@ -126,4 +128,4 @@
   </script> 
   </script> 
 
 
 </body>
 </body>
-</html>
+</html>
Discard
@@ -3,10 +3,11 @@
 <head>
 <head>
   <meta charset="utf-8" />
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-  <title>super_gradients.common.data_types.enum.evaluation_type &mdash; SuperGradients 1.0 documentation</title>
+  <title>super_gradients.common.data_types.enum.evaluation_type &mdash; SuperGradients 3.0.3 documentation</title>
       <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
+      <link rel="stylesheet" href="../../../../../_static/custom.css" type="text/css" />
   <!--[if lt IE 9]>
   <!--[if lt IE 9]>
     <script src="../../../../../_static/js/html5shiv.min.js"></script>
     <script src="../../../../../_static/js/html5shiv.min.js"></script>
   <![endif]-->
   <![endif]-->
@@ -14,7 +15,9 @@
         <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
         <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
         <script src="../../../../../_static/jquery.js"></script>
         <script src="../../../../../_static/jquery.js"></script>
         <script src="../../../../../_static/underscore.js"></script>
         <script src="../../../../../_static/underscore.js"></script>
+        <script src="../../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
         <script src="../../../../../_static/doctools.js"></script>
         <script src="../../../../../_static/doctools.js"></script>
+        <script src="../../../../../_static/sphinx_highlight.js"></script>
     <script src="../../../../../_static/js/theme.js"></script>
     <script src="../../../../../_static/js/theme.js"></script>
     <link rel="index" title="Index" href="../../../../../genindex.html" />
     <link rel="index" title="Index" href="../../../../../genindex.html" />
     <link rel="search" title="Search" href="../../../../../search.html" /> 
     <link rel="search" title="Search" href="../../../../../search.html" /> 
@@ -35,30 +38,29 @@
   </form>
   </form>
 </div>
 </div>
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
-              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
+              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
 <ul>
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#supergradients">SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#quick-installation">Quick Installation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#what-s-new">What’s New</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#coming-soon">Coming soon</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#table-of-content">Table of Content</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#getting-started">Getting Started</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#advanced-features">Advanced Features</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#installation-methods">Installation Methods</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#documentation">Documentation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#contributing">Contributing</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#citation">Citation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#community">Community</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#license">License</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#deci-platform">Deci Platform</a></li>
 </ul>
 </ul>
-<p class="caption"><span class="caption-text">Technical Documentation</span></p>
+<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
 <ul>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
-</ul>
-<p class="caption"><span class="caption-text">User Guide</span></p>
-<ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html">What is SuperGradients?</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#installation">Installation</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#training-parameters">Training Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#network-architectures">Network Architectures</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
 </ul>
 </ul>
 
 
         </div>
         </div>
@@ -89,11 +91,11 @@
 <span></span><span class="kn">from</span> <span class="nn">enum</span> <span class="kn">import</span> <span class="n">Enum</span>
 <span></span><span class="kn">from</span> <span class="nn">enum</span> <span class="kn">import</span> <span class="n">Enum</span>
 
 
 
 
-<div class="viewcode-block" id="EvaluationType"><a class="viewcode-back" href="../../../../../super_gradients.common.data_types.enum.html#super_gradients.common.EvaluationType">[docs]</a><span class="k">class</span> <span class="nc">EvaluationType</span><span class="p">(</span><span class="nb">str</span><span class="p">,</span> <span class="n">Enum</span><span class="p">):</span>
+<div class="viewcode-block" id="EvaluationType"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.common.EvaluationType">[docs]</a><span class="k">class</span> <span class="nc">EvaluationType</span><span class="p">(</span><span class="nb">str</span><span class="p">,</span> <span class="n">Enum</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    EvaluationType</span>
 <span class="sd">    EvaluationType</span>
 
 
-<span class="sd">    Passed to SgModel.evaluate(..), and controls which phase callbacks should be triggered (if at all).</span>
+<span class="sd">    Passed to Trainer.evaluate(..), and controls which phase callbacks should be triggered (if at all).</span>
 
 
 <span class="sd">        Attributes:</span>
 <span class="sd">        Attributes:</span>
 <span class="sd">            TEST</span>
 <span class="sd">            TEST</span>
@@ -131,4 +133,4 @@
   </script> 
   </script> 
 
 
 </body>
 </body>
-</html>
+</html>
Discard
@@ -3,10 +3,11 @@
 <head>
 <head>
   <meta charset="utf-8" />
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-  <title>super_gradients.common.data_types.enum.multi_gpu_mode &mdash; SuperGradients 1.0 documentation</title>
+  <title>super_gradients.common.data_types.enum.multi_gpu_mode &mdash; SuperGradients 3.0.3 documentation</title>
       <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
+      <link rel="stylesheet" href="../../../../../_static/custom.css" type="text/css" />
   <!--[if lt IE 9]>
   <!--[if lt IE 9]>
     <script src="../../../../../_static/js/html5shiv.min.js"></script>
     <script src="../../../../../_static/js/html5shiv.min.js"></script>
   <![endif]-->
   <![endif]-->
@@ -14,7 +15,9 @@
         <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
         <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
         <script src="../../../../../_static/jquery.js"></script>
         <script src="../../../../../_static/jquery.js"></script>
         <script src="../../../../../_static/underscore.js"></script>
         <script src="../../../../../_static/underscore.js"></script>
+        <script src="../../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
         <script src="../../../../../_static/doctools.js"></script>
         <script src="../../../../../_static/doctools.js"></script>
+        <script src="../../../../../_static/sphinx_highlight.js"></script>
     <script src="../../../../../_static/js/theme.js"></script>
     <script src="../../../../../_static/js/theme.js"></script>
     <link rel="index" title="Index" href="../../../../../genindex.html" />
     <link rel="index" title="Index" href="../../../../../genindex.html" />
     <link rel="search" title="Search" href="../../../../../search.html" /> 
     <link rel="search" title="Search" href="../../../../../search.html" /> 
@@ -35,30 +38,29 @@
   </form>
   </form>
 </div>
 </div>
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
-              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
+              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
 <ul>
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#supergradients">SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#quick-installation">Quick Installation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#what-s-new">What’s New</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#coming-soon">Coming soon</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#table-of-content">Table of Content</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#getting-started">Getting Started</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#advanced-features">Advanced Features</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#installation-methods">Installation Methods</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#documentation">Documentation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#contributing">Contributing</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#citation">Citation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#community">Community</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#license">License</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#deci-platform">Deci Platform</a></li>
 </ul>
 </ul>
-<p class="caption"><span class="caption-text">Technical Documentation</span></p>
+<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
 <ul>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
-</ul>
-<p class="caption"><span class="caption-text">User Guide</span></p>
-<ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html">What is SuperGradients?</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#installation">Installation</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#training-parameters">Training Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#network-architectures">Network Architectures</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
 </ul>
 </ul>
 
 
         </div>
         </div>
@@ -87,9 +89,10 @@
              
              
   <h1>Source code for super_gradients.common.data_types.enum.multi_gpu_mode</h1><div class="highlight"><pre>
   <h1>Source code for super_gradients.common.data_types.enum.multi_gpu_mode</h1><div class="highlight"><pre>
 <span></span><span class="kn">from</span> <span class="nn">enum</span> <span class="kn">import</span> <span class="n">Enum</span>
 <span></span><span class="kn">from</span> <span class="nn">enum</span> <span class="kn">import</span> <span class="n">Enum</span>
+<span class="kn">import</span> <span class="nn">stringcase</span>
 
 
 
 
-<div class="viewcode-block" id="MultiGPUMode"><a class="viewcode-back" href="../../../../../super_gradients.training.sg_model.html#super_gradients.common.MultiGPUMode">[docs]</a><span class="k">class</span> <span class="nc">MultiGPUMode</span><span class="p">(</span><span class="nb">str</span><span class="p">,</span> <span class="n">Enum</span><span class="p">):</span>
+<div class="viewcode-block" id="MultiGPUMode"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.common.MultiGPUMode">[docs]</a><span class="k">class</span> <span class="nc">MultiGPUMode</span><span class="p">(</span><span class="nb">str</span><span class="p">,</span> <span class="n">Enum</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    MultiGPUMode</span>
 <span class="sd">    MultiGPUMode</span>
 
 
@@ -98,10 +101,26 @@
 <span class="sd">            DATA_PARALLEL             - Multiple GPUs, Synchronous</span>
 <span class="sd">            DATA_PARALLEL             - Multiple GPUs, Synchronous</span>
 <span class="sd">            DISTRIBUTED_DATA_PARALLEL - Multiple GPUs, Asynchronous</span>
 <span class="sd">            DISTRIBUTED_DATA_PARALLEL - Multiple GPUs, Asynchronous</span>
 <span class="sd">    &quot;&quot;&quot;</span>
 <span class="sd">    &quot;&quot;&quot;</span>
-    <span class="n">OFF</span> <span class="o">=</span> <span class="s1">&#39;Off&#39;</span>
-    <span class="n">DATA_PARALLEL</span> <span class="o">=</span> <span class="s1">&#39;DP&#39;</span>
-    <span class="n">DISTRIBUTED_DATA_PARALLEL</span> <span class="o">=</span> <span class="s1">&#39;DDP&#39;</span>
-    <span class="n">AUTO</span> <span class="o">=</span> <span class="s2">&quot;AUTO&quot;</span></div>
+
+    <span class="n">OFF</span> <span class="o">=</span> <span class="s2">&quot;Off&quot;</span>
+    <span class="n">DATA_PARALLEL</span> <span class="o">=</span> <span class="s2">&quot;DP&quot;</span>
+    <span class="n">DISTRIBUTED_DATA_PARALLEL</span> <span class="o">=</span> <span class="s2">&quot;DDP&quot;</span>
+    <span class="n">AUTO</span> <span class="o">=</span> <span class="s2">&quot;AUTO&quot;</span>
+
+<div class="viewcode-block" id="MultiGPUMode.dict"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.common.MultiGPUMode.dict">[docs]</a>    <span class="nd">@classmethod</span>
+    <span class="k">def</span> <span class="nf">dict</span><span class="p">(</span><span class="bp">cls</span><span class="p">):</span>
+        <span class="sd">&quot;&quot;&quot;</span>
+<span class="sd">        return dictionary mapping from the mode name (in call string cases) to the enum value</span>
+<span class="sd">        &quot;&quot;&quot;</span>
+        <span class="n">out_dict</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
+        <span class="k">for</span> <span class="n">mode</span> <span class="ow">in</span> <span class="n">MultiGPUMode</span><span class="p">:</span>
+            <span class="n">out_dict</span><span class="p">[</span><span class="n">mode</span><span class="o">.</span><span class="n">value</span><span class="p">]</span> <span class="o">=</span> <span class="n">mode</span>
+            <span class="n">out_dict</span><span class="p">[</span><span class="n">mode</span><span class="o">.</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="n">mode</span>
+            <span class="n">out_dict</span><span class="p">[</span><span class="n">stringcase</span><span class="o">.</span><span class="n">capitalcase</span><span class="p">(</span><span class="n">mode</span><span class="o">.</span><span class="n">name</span><span class="p">)]</span> <span class="o">=</span> <span class="n">mode</span>
+            <span class="n">out_dict</span><span class="p">[</span><span class="n">stringcase</span><span class="o">.</span><span class="n">camelcase</span><span class="p">(</span><span class="n">mode</span><span class="o">.</span><span class="n">name</span><span class="p">)]</span> <span class="o">=</span> <span class="n">mode</span>
+            <span class="n">out_dict</span><span class="p">[</span><span class="n">stringcase</span><span class="o">.</span><span class="n">lowercase</span><span class="p">(</span><span class="n">mode</span><span class="o">.</span><span class="n">name</span><span class="p">)]</span> <span class="o">=</span> <span class="n">mode</span>
+        <span class="n">out_dict</span><span class="p">[</span><span class="kc">False</span><span class="p">]</span> <span class="o">=</span> <span class="n">MultiGPUMode</span><span class="o">.</span><span class="n">OFF</span>
+        <span class="k">return</span> <span class="n">out_dict</span></div></div>
 </pre></div>
 </pre></div>
 
 
            </div>
            </div>
@@ -131,4 +150,4 @@
   </script> 
   </script> 
 
 
 </body>
 </body>
-</html>
+</html>
Discard
@@ -3,10 +3,11 @@
 <head>
 <head>
   <meta charset="utf-8" />
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-  <title>super_gradients.common.data_types.enum.strict_load &mdash; SuperGradients 1.0 documentation</title>
+  <title>super_gradients.common.data_types.enum.strict_load &mdash; SuperGradients 3.0.3 documentation</title>
       <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
+      <link rel="stylesheet" href="../../../../../_static/custom.css" type="text/css" />
   <!--[if lt IE 9]>
   <!--[if lt IE 9]>
     <script src="../../../../../_static/js/html5shiv.min.js"></script>
     <script src="../../../../../_static/js/html5shiv.min.js"></script>
   <![endif]-->
   <![endif]-->
@@ -14,7 +15,9 @@
         <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
         <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
         <script src="../../../../../_static/jquery.js"></script>
         <script src="../../../../../_static/jquery.js"></script>
         <script src="../../../../../_static/underscore.js"></script>
         <script src="../../../../../_static/underscore.js"></script>
+        <script src="../../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
         <script src="../../../../../_static/doctools.js"></script>
         <script src="../../../../../_static/doctools.js"></script>
+        <script src="../../../../../_static/sphinx_highlight.js"></script>
     <script src="../../../../../_static/js/theme.js"></script>
     <script src="../../../../../_static/js/theme.js"></script>
     <link rel="index" title="Index" href="../../../../../genindex.html" />
     <link rel="index" title="Index" href="../../../../../genindex.html" />
     <link rel="search" title="Search" href="../../../../../search.html" /> 
     <link rel="search" title="Search" href="../../../../../search.html" /> 
@@ -35,30 +38,29 @@
   </form>
   </form>
 </div>
 </div>
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
-              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
+              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
 <ul>
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#supergradients">SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#quick-installation">Quick Installation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#what-s-new">What’s New</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#coming-soon">Coming soon</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#table-of-content">Table of Content</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#getting-started">Getting Started</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#advanced-features">Advanced Features</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#installation-methods">Installation Methods</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#documentation">Documentation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#contributing">Contributing</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#citation">Citation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#community">Community</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#license">License</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#deci-platform">Deci Platform</a></li>
 </ul>
 </ul>
-<p class="caption"><span class="caption-text">Technical Documentation</span></p>
+<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
 <ul>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
-</ul>
-<p class="caption"><span class="caption-text">User Guide</span></p>
-<ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html">What is SuperGradients?</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#installation">Installation</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#training-parameters">Training Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#network-architectures">Network Architectures</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
 </ul>
 </ul>
 
 
         </div>
         </div>
@@ -89,18 +91,19 @@
 <span></span><span class="kn">from</span> <span class="nn">enum</span> <span class="kn">import</span> <span class="n">Enum</span>
 <span></span><span class="kn">from</span> <span class="nn">enum</span> <span class="kn">import</span> <span class="n">Enum</span>
 
 
 
 
-<div class="viewcode-block" id="StrictLoad"><a class="viewcode-back" href="../../../../../super_gradients.training.sg_model.html#super_gradients.common.StrictLoad">[docs]</a><span class="k">class</span> <span class="nc">StrictLoad</span><span class="p">(</span><span class="n">Enum</span><span class="p">):</span>
+<div class="viewcode-block" id="StrictLoad"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.common.StrictLoad">[docs]</a><span class="k">class</span> <span class="nc">StrictLoad</span><span class="p">(</span><span class="n">Enum</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    Wrapper for adding more functionality to torch&#39;s strict_load parameter in load_state_dict().</span>
 <span class="sd">    Wrapper for adding more functionality to torch&#39;s strict_load parameter in load_state_dict().</span>
-<span class="sd">        Attributes:</span>
-<span class="sd">            OFF              - Native torch &quot;strict_load = off&quot; behaviour. See nn.Module.load_state_dict() documentation for more details.</span>
-<span class="sd">            ON               - Native torch &quot;strict_load = on&quot; behaviour. See nn.Module.load_state_dict() documentation for more details.</span>
-<span class="sd">            NO_KEY_MATCHING  - Allows the usage of SuperGradient&#39;s adapt_checkpoint function, which loads a checkpoint by matching each</span>
-<span class="sd">                               layer&#39;s shapes (and bypasses the strict matching of the names of each layer (ie: disregards the state_dict key matching)).</span>
+<span class="sd">    Attributes:</span>
+<span class="sd">        OFF              - Native torch &quot;strict_load = off&quot; behaviour. See nn.Module.load_state_dict() documentation for more details.</span>
+<span class="sd">        ON               - Native torch &quot;strict_load = on&quot; behaviour. See nn.Module.load_state_dict() documentation for more details.</span>
+<span class="sd">        NO_KEY_MATCHING  - Allows the usage of SuperGradient&#39;s adapt_checkpoint function, which loads a checkpoint by matching each</span>
+<span class="sd">                           layer&#39;s shapes (and bypasses the strict matching of the names of each layer (ie: disregards the state_dict key matching)).</span>
 <span class="sd">    &quot;&quot;&quot;</span>
 <span class="sd">    &quot;&quot;&quot;</span>
+
     <span class="n">OFF</span> <span class="o">=</span> <span class="kc">False</span>
     <span class="n">OFF</span> <span class="o">=</span> <span class="kc">False</span>
     <span class="n">ON</span> <span class="o">=</span> <span class="kc">True</span>
     <span class="n">ON</span> <span class="o">=</span> <span class="kc">True</span>
-    <span class="n">NO_KEY_MATCHING</span> <span class="o">=</span> <span class="s1">&#39;no_key_matching&#39;</span></div>
+    <span class="n">NO_KEY_MATCHING</span> <span class="o">=</span> <span class="s2">&quot;no_key_matching&quot;</span></div>
 </pre></div>
 </pre></div>
 
 
            </div>
            </div>
@@ -130,4 +133,4 @@
   </script> 
   </script> 
 
 
 </body>
 </body>
-</html>
+</html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.common.data_types.enum.upsample_mode &mdash; SuperGradients 3.0.3 documentation</title>
  7. <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
  10. <link rel="stylesheet" href="../../../../../_static/custom.css" type="text/css" />
  11. <!--[if lt IE 9]>
  12. <script src="../../../../../_static/js/html5shiv.min.js"></script>
  13. <![endif]-->
  14. <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
  15. <script src="../../../../../_static/jquery.js"></script>
  16. <script src="../../../../../_static/underscore.js"></script>
  17. <script src="../../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
  18. <script src="../../../../../_static/doctools.js"></script>
  19. <script src="../../../../../_static/sphinx_highlight.js"></script>
  20. <script src="../../../../../_static/js/theme.js"></script>
  21. <link rel="index" title="Index" href="../../../../../genindex.html" />
  22. <link rel="search" title="Search" href="../../../../../search.html" />
  23. </head>
  24. <body class="wy-body-for-nav">
  25. <div class="wy-grid-for-nav">
  26. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  27. <div class="wy-side-scroll">
  28. <div class="wy-side-nav-search" >
  29. <a href="../../../../../index.html" class="icon icon-home"> SuperGradients
  30. </a>
  31. <div role="search">
  32. <form id="rtd-search-form" class="wy-form" action="../../../../../search.html" method="get">
  33. <input type="text" name="q" placeholder="Search docs" />
  34. <input type="hidden" name="check_keywords" value="yes" />
  35. <input type="hidden" name="area" value="default" />
  36. </form>
  37. </div>
  38. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  39. <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
  40. <ul>
  41. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
  42. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#quick-installation">Quick Installation</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#what-s-new">What’s New</a></li>
  45. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#coming-soon">Coming soon</a></li>
  46. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#table-of-content">Table of Content</a></li>
  47. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#getting-started">Getting Started</a></li>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#advanced-features">Advanced Features</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#installation-methods">Installation Methods</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#documentation">Documentation</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#contributing">Contributing</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#citation">Citation</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#community">Community</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#license">License</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#deci-platform">Deci Platform</a></li>
  57. </ul>
  58. <p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
  59. <ul>
  60. <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
  61. <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
  62. </ul>
  63. </div>
  64. </div>
  65. </nav>
  66. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  67. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  68. <a href="../../../../../index.html">SuperGradients</a>
  69. </nav>
  70. <div class="wy-nav-content">
  71. <div class="rst-content">
  72. <div role="navigation" aria-label="Page navigation">
  73. <ul class="wy-breadcrumbs">
  74. <li><a href="../../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  75. <li><a href="../../../../index.html">Module code</a> &raquo;</li>
  76. <li>super_gradients.common.data_types.enum.upsample_mode</li>
  77. <li class="wy-breadcrumbs-aside">
  78. </li>
  79. </ul>
  80. <hr/>
  81. </div>
  82. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  83. <div itemprop="articleBody">
  84. <h1>Source code for super_gradients.common.data_types.enum.upsample_mode</h1><div class="highlight"><pre>
  85. <span></span><span class="kn">from</span> <span class="nn">enum</span> <span class="kn">import</span> <span class="n">Enum</span>
  86. <div class="viewcode-block" id="UpsampleMode"><a class="viewcode-back" href="../../../../../super_gradients.common.html#super_gradients.common.UpsampleMode">[docs]</a><span class="k">class</span> <span class="nc">UpsampleMode</span><span class="p">(</span><span class="n">Enum</span><span class="p">):</span>
  87. <span class="n">NEAREST</span> <span class="o">=</span> <span class="s2">&quot;nearest&quot;</span>
  88. <span class="n">BILINEAR</span> <span class="o">=</span> <span class="s2">&quot;bilinear&quot;</span>
  89. <span class="n">BICUBIC</span> <span class="o">=</span> <span class="s2">&quot;bicubic&quot;</span>
  90. <span class="n">SNPE_BILINEAR</span> <span class="o">=</span> <span class="s2">&quot;snpe_bilinear&quot;</span></div>
  91. </pre></div>
  92. </div>
  93. </div>
  94. <footer>
  95. <hr/>
  96. <div role="contentinfo">
  97. <p>&#169; Copyright 2021, SuperGradients team.</p>
  98. </div>
  99. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  100. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  101. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  102. </footer>
  103. </div>
  104. </div>
  105. </section>
  106. </div>
  107. <script>
  108. jQuery(function () {
  109. SphinxRtdTheme.Navigation.enable(true);
  110. });
  111. </script>
  112. </body>
  113. </html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.common.decorators.deci_logger &mdash; SuperGradients 1.0 documentation</title>
  7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
  10. <!--[if lt IE 9]>
  11. <script src="../../../../_static/js/html5shiv.min.js"></script>
  12. <![endif]-->
  13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
  14. <script src="../../../../_static/jquery.js"></script>
  15. <script src="../../../../_static/underscore.js"></script>
  16. <script src="../../../../_static/doctools.js"></script>
  17. <script src="../../../../_static/js/theme.js"></script>
  18. <link rel="index" title="Index" href="../../../../genindex.html" />
  19. <link rel="search" title="Search" href="../../../../search.html" />
  20. </head>
  21. <body class="wy-body-for-nav">
  22. <div class="wy-grid-for-nav">
  23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  24. <div class="wy-side-scroll">
  25. <div class="wy-side-nav-search" >
  26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
  27. </a>
  28. <div role="search">
  29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
  30. <input type="text" name="q" placeholder="Search docs" />
  31. <input type="hidden" name="check_keywords" value="yes" />
  32. <input type="hidden" name="area" value="default" />
  33. </form>
  34. </div>
  35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
  37. <ul>
  38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
  39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
  40. </ul>
  41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
  42. <ul>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
  45. </ul>
  46. <p class="caption"><span class="caption-text">User Guide</span></p>
  47. <ul>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
  57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
  58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
  59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
  60. </ul>
  61. </div>
  62. </div>
  63. </nav>
  64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  66. <a href="../../../../index.html">SuperGradients</a>
  67. </nav>
  68. <div class="wy-nav-content">
  69. <div class="rst-content">
  70. <div role="navigation" aria-label="Page navigation">
  71. <ul class="wy-breadcrumbs">
  72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
  74. <li>super_gradients.common.decorators.deci_logger</li>
  75. <li class="wy-breadcrumbs-aside">
  76. </li>
  77. </ul>
  78. <hr/>
  79. </div>
  80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  81. <div itemprop="articleBody">
  82. <h1>Source code for super_gradients.common.decorators.deci_logger</h1><div class="highlight"><pre>
  83. <div class="viewcode-block" id="deci_func_logger"><a class="viewcode-back" href="../../../../super_gradients.common.decorators.html#super_gradients.common.decorators.deci_logger.deci_func_logger">[docs]</a><span></span><span class="k">def</span> <span class="nf">deci_func_logger</span><span class="p">(</span><span class="n">_func</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span> <span class="n">name</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39;abstract_decorator&#39;</span><span class="p">):</span>
  84. <span class="sd">&quot;&quot;&quot;</span>
  85. <span class="sd"> This decorator is used to wrap our functions with logs.</span>
  86. <span class="sd"> It will log every enter and exit of the functon with the equivalent parameters as extras.</span>
  87. <span class="sd"> It will also log exceptions that raises in the function.</span>
  88. <span class="sd"> It will also log the exception time of the function.</span>
  89. <span class="sd"> How it works:`</span>
  90. <span class="sd"> First it will check if the decorator called with name keyword.</span>
  91. <span class="sd"> If so it will return a new decorator that its logger is the name parameter.</span>
  92. <span class="sd"> If not it will return a new decorator that its logger is the wrapped function name.</span>
  93. <span class="sd"> Then the return decorator will return a new function that warps the original function with the new logs.</span>
  94. <span class="sd"> For further understanding advise real-python &quot;fancy decorators documentation&quot;</span>
  95. <span class="sd"> Args:</span>
  96. <span class="sd"> _func (): used when called without name specify. dont pass it directly</span>
  97. <span class="sd"> name (): The name of the logger to save logs by.</span>
  98. <span class="sd"> Returns:</span>
  99. <span class="sd"> a decorator that wraps function with logs logic.</span>
  100. <span class="sd"> &quot;&quot;&quot;</span>
  101. <span class="c1"># TODO: Not Working - Breaks the code, tests does not pass (s3 connector, platform...)</span>
  102. <span class="c1"># TODO: Fix problem with ExplicitParamValidation error (arguments not passed)</span>
  103. <span class="c1"># TODO: Run ALL test suite of deci2 (NOT circieCI test suite, but ALL the tests under tests folders)</span>
  104. <span class="c1"># TODO: Delete/Update all failing tests.</span>
  105. <span class="c1"># def deci_logger_decorator(fn):</span>
  106. <span class="c1">#</span>
  107. <span class="c1"># @functools.wraps(fn)</span>
  108. <span class="c1"># def wrapper_func(*args, **kwargs):</span>
  109. <span class="c1"># try:</span>
  110. <span class="c1">#</span>
  111. <span class="c1"># try:</span>
  112. <span class="c1"># logger.debug(f&quot;Start: {fn.__name__}&quot;, extra={&quot;args&quot;: args, &quot;kwargs&quot;: kwargs})</span>
  113. <span class="c1"># time1 = time.perf_counter()</span>
  114. <span class="c1"># except Exception:</span>
  115. <span class="c1"># # failed to write log - continue.</span>
  116. <span class="c1"># pass</span>
  117. <span class="c1">#</span>
  118. <span class="c1"># result = fn(*args, **kwargs)</span>
  119. <span class="c1">#</span>
  120. <span class="c1"># try:</span>
  121. <span class="c1"># time2 = time.perf_counter()</span>
  122. <span class="c1"># logger.debug(f&quot;End: {fn.__name__}&quot;,</span>
  123. <span class="c1"># extra={&#39;duration&#39;: (time2 - time1) * 1000.0, &#39;return_value&#39;: result})</span>
  124. <span class="c1"># except Exception:</span>
  125. <span class="c1"># # failed to write log - continue.</span>
  126. <span class="c1"># pass</span>
  127. <span class="c1">#</span>
  128. <span class="c1"># return result</span>
  129. <span class="c1">#</span>
  130. <span class="c1"># except Exception as ex:</span>
  131. <span class="c1"># # This exception was raised from inside the function call</span>
  132. <span class="c1"># logger.error(f&quot;Exception: {ex}&quot;, exc_info=ex)</span>
  133. <span class="c1"># raise ex</span>
  134. <span class="c1">#</span>
  135. <span class="c1"># return wrapper_func</span>
  136. <span class="c1"># if _func is None:</span>
  137. <span class="c1"># logger = get_logger(name)</span>
  138. <span class="c1"># return deci_logger_decorator</span>
  139. <span class="c1"># else:</span>
  140. <span class="c1"># logger = get_logger(_func.__name__)</span>
  141. <span class="c1"># return deci_logger_decorator(_func)</span>
  142. <span class="k">return</span> <span class="n">_func</span></div>
  143. <div class="viewcode-block" id="deci_class_logger"><a class="viewcode-back" href="../../../../super_gradients.common.decorators.html#super_gradients.common.decorators.deci_logger.deci_class_logger">[docs]</a><span class="k">def</span> <span class="nf">deci_class_logger</span><span class="p">():</span>
  144. <span class="sd">&quot;&quot;&quot;</span>
  145. <span class="sd"> This decorator wraps every class method with deci_func_logger decorator.</span>
  146. <span class="sd"> It works by checking if class method is callable and if so it will set a new decorated method as the same method name.</span>
  147. <span class="sd"> &quot;&quot;&quot;</span>
  148. <span class="k">def</span> <span class="nf">wrapper</span><span class="p">(</span><span class="bp">cls</span><span class="p">):</span>
  149. <span class="c1"># TODO: Not Working - Breaks the code, tests does not pass (s3 connector, platform...)</span>
  150. <span class="c1"># TODO: Fix problem with ExplicitParamValidation error (arguments not passed)</span>
  151. <span class="c1"># TODO: Run ALL test suite of deci2 (NOT circieCI test suite, but ALL the tests under tests folders)</span>
  152. <span class="c1"># TODO: Delete/Update all failing tests.</span>
  153. <span class="c1"># for attr in cls.__dict__:</span>
  154. <span class="c1"># if callable(getattr(cls, attr)) and attr != &#39;__init__&#39;:</span>
  155. <span class="c1"># decorated_function = deci_func_logger(name=cls.__name__)(getattr(cls, attr))</span>
  156. <span class="c1"># if type(cls.__dict__[attr]) is staticmethod:</span>
  157. <span class="c1"># decorated_function = staticmethod(decorated_function)</span>
  158. <span class="c1"># setattr(cls, attr, decorated_function)</span>
  159. <span class="k">return</span> <span class="bp">cls</span>
  160. <span class="k">return</span> <span class="n">wrapper</span></div>
  161. </pre></div>
  162. </div>
  163. </div>
  164. <footer>
  165. <hr/>
  166. <div role="contentinfo">
  167. <p>&#169; Copyright 2021, SuperGradients team.</p>
  168. </div>
  169. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  170. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  171. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  172. </footer>
  173. </div>
  174. </div>
  175. </section>
  176. </div>
  177. <script>
  178. jQuery(function () {
  179. SphinxRtdTheme.Navigation.enable(true);
  180. });
  181. </script>
  182. </body>
  183. </html>
Discard
Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
@@ -3,10 +3,11 @@
 <head>
 <head>
   <meta charset="utf-8" />
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-  <title>super_gradients.common.decorators.explicit_params_validator &mdash; SuperGradients 1.0 documentation</title>
+  <title>super_gradients.common.decorators.explicit_params_validator &mdash; SuperGradients 3.0.3 documentation</title>
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
+      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
   <!--[if lt IE 9]>
   <!--[if lt IE 9]>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
   <![endif]-->
   <![endif]-->
@@ -14,7 +15,9 @@
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/underscore.js"></script>
         <script src="../../../../_static/underscore.js"></script>
+        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
         <script src="../../../../_static/doctools.js"></script>
         <script src="../../../../_static/doctools.js"></script>
+        <script src="../../../../_static/sphinx_highlight.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="search" title="Search" href="../../../../search.html" /> 
     <link rel="search" title="Search" href="../../../../search.html" /> 
@@ -35,30 +38,29 @@
   </form>
   </form>
 </div>
 </div>
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
-              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
+              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
 <ul>
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
 </ul>
 </ul>
-<p class="caption"><span class="caption-text">Technical Documentation</span></p>
+<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
 <ul>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
-</ul>
-<p class="caption"><span class="caption-text">User Guide</span></p>
-<ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
 </ul>
 </ul>
 
 
         </div>
         </div>
@@ -160,7 +162,7 @@
 
 
 
 
 <span class="c1"># WRAPS THE RETRY DECORATOR CLASS TO ENABLE CALLING WITHOUT PARAMS</span>
 <span class="c1"># WRAPS THE RETRY DECORATOR CLASS TO ENABLE CALLING WITHOUT PARAMS</span>
-<div class="viewcode-block" id="explicit_params_validation"><a class="viewcode-back" href="../../../../super_gradients.common.decorators.html#super_gradients.common.explicit_params_validation">[docs]</a><span class="k">def</span> <span class="nf">explicit_params_validation</span><span class="p">(</span><span class="n">function</span><span class="p">:</span> <span class="n">Callable</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">validation_ty
+<div class="viewcode-block" id="explicit_params_validation"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.explicit_params_validation">[docs]</a><span class="k">def</span> <span class="nf">explicit_params_validation</span><span class="p">(</span><span class="n">function</span><span class="p">:</span> <span class="n">Callable</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">validation_type</span><s
     <span class="k">if</span> <span class="n">function</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
     <span class="k">if</span> <span class="n">function</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
         <span class="k">return</span> <span class="n">_ExplicitParamsValidator</span><span class="p">(</span><span class="n">function</span><span class="o">=</span><span class="n">function</span><span class="p">)</span>
         <span class="k">return</span> <span class="n">_ExplicitParamsValidator</span><span class="p">(</span><span class="n">function</span><span class="o">=</span><span class="n">function</span><span class="p">)</span>
     <span class="k">else</span><span class="p">:</span>
     <span class="k">else</span><span class="p">:</span>
@@ -197,4 +199,4 @@
   </script> 
   </script> 
 
 
 </body>
 </body>
-</html>
+</html>
Discard
Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
@@ -3,10 +3,11 @@
 <head>
 <head>
   <meta charset="utf-8" />
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-  <title>super_gradients.common.decorators.singleton &mdash; SuperGradients 1.0 documentation</title>
+  <title>super_gradients.common.decorators.singleton &mdash; SuperGradients 3.0.3 documentation</title>
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
+      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
   <!--[if lt IE 9]>
   <!--[if lt IE 9]>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
   <![endif]-->
   <![endif]-->
@@ -14,7 +15,9 @@
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/underscore.js"></script>
         <script src="../../../../_static/underscore.js"></script>
+        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
         <script src="../../../../_static/doctools.js"></script>
         <script src="../../../../_static/doctools.js"></script>
+        <script src="../../../../_static/sphinx_highlight.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="search" title="Search" href="../../../../search.html" /> 
     <link rel="search" title="Search" href="../../../../search.html" /> 
@@ -35,30 +38,29 @@
   </form>
   </form>
 </div>
 </div>
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
-              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
+              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
 <ul>
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
 </ul>
 </ul>
-<p class="caption"><span class="caption-text">Technical Documentation</span></p>
+<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
 <ul>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
-</ul>
-<p class="caption"><span class="caption-text">User Guide</span></p>
-<ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
 </ul>
 </ul>
 
 
         </div>
         </div>
@@ -86,7 +88,7 @@
            <div itemprop="articleBody">
            <div itemprop="articleBody">
              
              
   <h1>Source code for super_gradients.common.decorators.singleton</h1><div class="highlight"><pre>
   <h1>Source code for super_gradients.common.decorators.singleton</h1><div class="highlight"><pre>
-<div class="viewcode-block" id="SingletonMeta"><a class="viewcode-back" href="../../../../super_gradients.common.decorators.html#super_gradients.common.SingletonMeta">[docs]</a><span></span><span class="k">class</span> <span class="nc">SingletonMeta</span><span class="p">(</span><span class="nb">type</span><span class="p">):</span>
+<span></span><span class="k">class</span> <span class="nc">SingletonMeta</span><span class="p">(</span><span class="nb">type</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    A Singleton meta class.</span>
 <span class="sd">    A Singleton meta class.</span>
 <span class="sd">    A class that derives from this class will have only 1 instance of that type for the process.</span>
 <span class="sd">    A class that derives from this class will have only 1 instance of that type for the process.</span>
@@ -97,7 +99,7 @@
     <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
     <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
         <span class="k">if</span> <span class="bp">cls</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">cls</span><span class="o">.</span><span class="n">_instances</span><span class="p">:</span>
         <span class="k">if</span> <span class="bp">cls</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">cls</span><span class="o">.</span><span class="n">_instances</span><span class="p">:</span>
             <span class="bp">cls</span><span class="o">.</span><span class="n">_instances</span><span class="p">[</span><span class="bp">cls</span><span class="p">]</span> <span class="o">=</span> <span class="nb">super</span><span class="p">(</span><span class="n">SingletonMeta</span><span class="p">,</span> <span class="bp">cls</span><span class="p">)</span><span class="o">.</span><span class="fm">__call__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span 
             <span class="bp">cls</span><span class="o">.</span><span class="n">_instances</span><span class="p">[</span><span class="bp">cls</span><span class="p">]</span> <span class="o">=</span> <span class="nb">super</span><span class="p">(</span><span class="n">SingletonMeta</span><span class="p">,</span> <span class="bp">cls</span><span class="p">)</span><span class="o">.</span><span class="fm">__call__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span 
-        <span class="k">return</span> <span class="bp">cls</span><span class="o">.</span><span class="n">_instances</span><span class="p">[</span><span class="bp">cls</span><span class="p">]</span></div>
+        <span class="k">return</span> <span class="bp">cls</span><span class="o">.</span><span class="n">_instances</span><span class="p">[</span><span class="bp">cls</span><span class="p">]</span>
 
 
 
 
 <span class="k">class</span> <span class="nc">_SingletonWrapper</span><span class="p">:</span>
 <span class="k">class</span> <span class="nc">_SingletonWrapper</span><span class="p">:</span>
@@ -117,7 +119,7 @@
         <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_instance</span>
         <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_instance</span>
 
 
 
 
-<div class="viewcode-block" id="singleton"><a class="viewcode-back" href="../../../../super_gradients.common.decorators.html#super_gradients.common.singleton">[docs]</a><span class="k">def</span> <span class="nf">singleton</span><span class="p">(</span><span class="bp">cls</span><span class="p">):</span>
+<div class="viewcode-block" id="singleton"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.singleton">[docs]</a><span class="k">def</span> <span class="nf">singleton</span><span class="p">(</span><span class="bp">cls</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    A singleton decorator. Returns a wrapper objects. A call on that object</span>
 <span class="sd">    A singleton decorator. Returns a wrapper objects. A call on that object</span>
 <span class="sd">    returns a single instance object of decorated class. Use the __wrapped__</span>
 <span class="sd">    returns a single instance object of decorated class. Use the __wrapped__</span>
@@ -153,4 +155,4 @@
   </script> 
   </script> 
 
 
 </body>
 </body>
-</html>
+</html>
Discard
Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
@@ -3,10 +3,11 @@
 <head>
 <head>
   <meta charset="utf-8" />
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-  <title>super_gradients.common.environment.env_helpers &mdash; SuperGradients 1.0 documentation</title>
+  <title>super_gradients.common.environment.env_helpers &mdash; SuperGradients 3.0.3 documentation</title>
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
+      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
   <!--[if lt IE 9]>
   <!--[if lt IE 9]>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
   <![endif]-->
   <![endif]-->
@@ -14,7 +15,9 @@
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/underscore.js"></script>
         <script src="../../../../_static/underscore.js"></script>
+        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
         <script src="../../../../_static/doctools.js"></script>
         <script src="../../../../_static/doctools.js"></script>
+        <script src="../../../../_static/sphinx_highlight.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="search" title="Search" href="../../../../search.html" /> 
     <link rel="search" title="Search" href="../../../../search.html" /> 
@@ -35,30 +38,29 @@
   </form>
   </form>
 </div>
 </div>
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
-              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
+              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
 <ul>
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
 </ul>
 </ul>
-<p class="caption"><span class="caption-text">Technical Documentation</span></p>
+<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
 <ul>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
-</ul>
-<p class="caption"><span class="caption-text">User Guide</span></p>
-<ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
 </ul>
 </ul>
 
 
         </div>
         </div>
@@ -87,38 +89,56 @@
              
              
   <h1>Source code for super_gradients.common.environment.env_helpers</h1><div class="highlight"><pre>
   <h1>Source code for super_gradients.common.environment.env_helpers</h1><div class="highlight"><pre>
 <span></span><span class="kn">import</span> <span class="nn">argparse</span>
 <span></span><span class="kn">import</span> <span class="nn">argparse</span>
+<span class="kn">import</span> <span class="nn">importlib</span>
 <span class="kn">import</span> <span class="nn">os</span>
 <span class="kn">import</span> <span class="nn">os</span>
+<span class="kn">import</span> <span class="nn">socket</span>
 <span class="kn">import</span> <span class="nn">sys</span>
 <span class="kn">import</span> <span class="nn">sys</span>
 <span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">wraps</span>
 <span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">wraps</span>
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span>
+
+<span class="kn">from</span> <span class="nn">omegaconf</span> <span class="kn">import</span> <span class="n">OmegaConf</span>
 
 
 <span class="kn">from</span> <span class="nn">super_gradients.common.environment</span> <span class="kn">import</span> <span class="n">environment_config</span>
 <span class="kn">from</span> <span class="nn">super_gradients.common.environment</span> <span class="kn">import</span> <span class="n">environment_config</span>
 
 
 
 
-<div class="viewcode-block" id="TerminalColours"><a class="viewcode-back" href="../../../../super_gradients.common.environment.html#super_gradients.common.TerminalColours">[docs]</a><span class="k">class</span> <span class="nc">TerminalColours</span><span class="p">:</span>
+<span class="k">class</span> <span class="nc">TerminalColours</span><span class="p">:</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    Usage: https://stackoverflow.com/questions/287871/how-to-print-colored-text-in-python?page=1&amp;tab=votes#tab-top</span>
 <span class="sd">    Usage: https://stackoverflow.com/questions/287871/how-to-print-colored-text-in-python?page=1&amp;tab=votes#tab-top</span>
 <span class="sd">    &quot;&quot;&quot;</span>
 <span class="sd">    &quot;&quot;&quot;</span>
-    <span class="n">HEADER</span> <span class="o">=</span> <span class="s1">&#39;</span><span class="se">\033</span><span class="s1">[95m&#39;</span>
-    <span class="n">OKBLUE</span> <span class="o">=</span> <span class="s1">&#39;</span><span class="se">\033</span><span class="s1">[94m&#39;</span>
-    <span class="n">OKCYAN</span> <span class="o">=</span> <span class="s1">&#39;</span><span class="se">\033</span><span class="s1">[96m&#39;</span>
-    <span class="n">OKGREEN</span> <span class="o">=</span> <span class="s1">&#39;</span><span class="se">\033</span><span class="s1">[92m&#39;</span>
-    <span class="n">WARNING</span> <span class="o">=</span> <span class="s1">&#39;</span><span class="se">\033</span><span class="s1">[93m&#39;</span>
-    <span class="n">FAIL</span> <span class="o">=</span> <span class="s1">&#39;</span><span class="se">\033</span><span class="s1">[91m&#39;</span>
-    <span class="n">ENDC</span> <span class="o">=</span> <span class="s1">&#39;</span><span class="se">\033</span><span class="s1">[0m&#39;</span>
-    <span class="n">BOLD</span> <span class="o">=</span> <span class="s1">&#39;</span><span class="se">\033</span><span class="s1">[1m&#39;</span>
-    <span class="n">UNDERLINE</span> <span class="o">=</span> <span class="s1">&#39;</span><span class="se">\033</span><span class="s1">[4m&#39;</span></div>
-
-
-<div class="viewcode-block" id="ColouredTextFormatter"><a class="viewcode-back" href="../../../../super_gradients.common.environment.html#super_gradients.common.ColouredTextFormatter">[docs]</a><span class="k">class</span> <span class="nc">ColouredTextFormatter</span><span class="p">:</span>
-<div class="viewcode-block" id="ColouredTextFormatter.print_coloured_text"><a class="viewcode-back" href="../../../../super_gradients.common.environment.html#super_gradients.common.ColouredTextFormatter.print_coloured_text">[docs]</a>    <span class="nd">@staticmethod</span>
+
+    <span class="n">HEADER</span> <span class="o">=</span> <span class="s2">&quot;</span><span class="se">\033</span><span class="s2">[95m&quot;</span>
+    <span class="n">OKBLUE</span> <span class="o">=</span> <span class="s2">&quot;</span><span class="se">\033</span><span class="s2">[94m&quot;</span>
+    <span class="n">OKCYAN</span> <span class="o">=</span> <span class="s2">&quot;</span><span class="se">\033</span><span class="s2">[96m&quot;</span>
+    <span class="n">OKGREEN</span> <span class="o">=</span> <span class="s2">&quot;</span><span class="se">\033</span><span class="s2">[92m&quot;</span>
+    <span class="n">WARNING</span> <span class="o">=</span> <span class="s2">&quot;</span><span class="se">\033</span><span class="s2">[93m&quot;</span>
+    <span class="n">FAIL</span> <span class="o">=</span> <span class="s2">&quot;</span><span class="se">\033</span><span class="s2">[91m&quot;</span>
+    <span class="n">ENDC</span> <span class="o">=</span> <span class="s2">&quot;</span><span class="se">\033</span><span class="s2">[0m&quot;</span>
+    <span class="n">BOLD</span> <span class="o">=</span> <span class="s2">&quot;</span><span class="se">\033</span><span class="s2">[1m&quot;</span>
+    <span class="n">UNDERLINE</span> <span class="o">=</span> <span class="s2">&quot;</span><span class="se">\033</span><span class="s2">[4m&quot;</span>
+
+
+<span class="k">class</span> <span class="nc">ColouredTextFormatter</span><span class="p">:</span>
+    <span class="nd">@staticmethod</span>
     <span class="k">def</span> <span class="nf">print_coloured_text</span><span class="p">(</span><span class="n">text</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">colour</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
     <span class="k">def</span> <span class="nf">print_coloured_text</span><span class="p">(</span><span class="n">text</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">colour</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">        Prints a text with colour ascii characters.</span>
 <span class="sd">        Prints a text with colour ascii characters.</span>
 <span class="sd">        &quot;&quot;&quot;</span>
 <span class="sd">        &quot;&quot;&quot;</span>
-        <span class="k">return</span> <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;&#39;</span><span class="o">.</span><span class="n">join</span><span class="p">([</span><span class="n">colour</span><span class="p">,</span> <span class="n">text</span><span class="p">,</span> <span class="n">TerminalColours</span><span class="o">.</span><span class="n">ENDC</span><span class="p">]))</span></div></div>
+        <span class="k">return</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;&quot;</span><span class="o">.</span><span class="n">join</span><span class="p">([</span><span class="n">colour</span><span class="p">,</span> <span class="n">text</span><span class="p">,</span> <span class="n">TerminalColours</span><span class="o">.</span><span class="n">ENDC</span><span class="p">]))</span>
+
+
+<span class="k">def</span> <span class="nf">get_cls</span><span class="p">(</span><span class="n">cls_path</span><span class="p">):</span>
+    <span class="sd">&quot;&quot;&quot;</span>
+<span class="sd">    A resolver for Hydra/OmegaConf to allow getting a class instead on an instance.</span>
+<span class="sd">    usage:</span>
+<span class="sd">    class_of_optimizer: ${class:torch.optim.Adam}</span>
+<span class="sd">    &quot;&quot;&quot;</span>
+    <span class="n">module</span> <span class="o">=</span> <span class="s2">&quot;.&quot;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">cls_path</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">&quot;.&quot;</span><span class="p">)[:</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
+    <span class="n">name</span> <span class="o">=</span> <span class="n">cls_path</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">&quot;.&quot;</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
+    <span class="n">importlib</span><span class="o">.</span><span class="n">import_module</span><span class="p">(</span><span class="n">module</span><span class="p">)</span>
+    <span class="k">return</span> <span class="nb">getattr</span><span class="p">(</span><span class="n">sys</span><span class="o">.</span><span class="n">modules</span><span class="p">[</span><span class="n">module</span><span class="p">],</span> <span class="n">name</span><span class="p">)</span>
 
 
 
 
-<div class="viewcode-block" id="get_environ_as_type"><a class="viewcode-back" href="../../../../super_gradients.common.environment.html#super_gradients.common.get_environ_as_type">[docs]</a><span class="k">def</span> <span class="nf">get_environ_as_type</span><span class="p">(</span><span class="n">environment_variable_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="kc">None</span><span cl
+<span class="k">def</span> <span class="nf">get_environ_as_type</span><span class="p">(</span><span class="n">environment_variable_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">cast_to_type</span><span class="p">:</span> <span class="nb">type</span> <span class="o">=</span> <span class="nb">str</span><span class="p">)</span> <span c
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    Tries to get an environment variable and cast it into a requested type.</span>
 <span class="sd">    Tries to get an environment variable and cast it into a requested type.</span>
 <span class="sd">    :return: cast_to_type object, or None if failed.</span>
 <span class="sd">    :return: cast_to_type object, or None if failed.</span>
@@ -131,41 +151,101 @@
         <span class="k">except</span> <span class="ne">Exception</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
         <span class="k">except</span> <span class="ne">Exception</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
             <span class="nb">print</span><span class="p">(</span><span class="n">e</span><span class="p">)</span>
             <span class="nb">print</span><span class="p">(</span><span class="n">e</span><span class="p">)</span>
             <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
             <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
-                <span class="sa">f</span><span class="s1">&#39;Failed to cast environment variable </span><span class="si">{</span><span class="n">environment_variable_name</span><span class="si">}</span><span class="s1"> to type </span><span class="si">{</span><span class="n">cast_to_type</span><span class="si">}</span><span class="s1">: the value </span><span class="si">{</span><span class="n">value</span><span class="si">}</span><span class="s1"> is not a valid </span><span class="si">{</spa
-    <span class="k">return</span></div>
+                <span class="sa">f</span><span class="s2">&quot;Failed to cast environment variable </span><span class="si">{</span><span class="n">environment_variable_name</span><span class="si">}</span><span class="s2"> to type </span><span class="si">{</span><span class="n">cast_to_type</span><span class="si">}</span><span class="s2">: the value </span><span class="si">{</span><span class="n">value</span><span class="si">}</span><span class="s2"> is not a valid </span><span class="si">{</sp
+            <span class="p">)</span>
+    <span class="k">return</span>
+
+
+<span class="k">def</span> <span class="nf">hydra_output_dir_resolver</span><span class="p">(</span><span class="n">ckpt_root_dir</span><span class="p">,</span> <span class="n">experiment_name</span><span class="p">):</span>
+    <span class="k">if</span> <span class="n">ckpt_root_dir</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+        <span class="n">output_dir_path</span> <span class="o">=</span> <span class="n">environment_config</span><span class="o">.</span><span class="n">PKG_CHECKPOINTS_DIR</span> <span class="o">+</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">sep</span> <span class="o">+</span> <span class="n">experiment_name</span>
+    <span class="k">else</span><span class="p">:</span>
+        <span class="n">output_dir_path</span> <span class="o">=</span> <span class="n">ckpt_root_dir</span> <span class="o">+</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">sep</span> <span class="o">+</span> <span class="n">experiment_name</span>
+    <span class="k">return</span> <span class="n">output_dir_path</span>
 
 
 
 
-<div class="viewcode-block" id="init_trainer"><a class="viewcode-back" href="../../../../super_gradients.common.environment.html#super_gradients.common.init_trainer">[docs]</a><span class="k">def</span> <span class="nf">init_trainer</span><span class="p">():</span>
+<div class="viewcode-block" id="init_trainer"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.init_trainer">[docs]</a><span class="k">def</span> <span class="nf">init_trainer</span><span class="p">():</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
-<span class="sd">    a function to initialize the super_gradients environment. This function should be the first thing to be called</span>
-<span class="sd">    by any code running super_gradients. It resolves conflicts between the different tools, packages and environments used</span>
-<span class="sd">    and prepares the super_gradients environment.</span>
+<span class="sd">    Initialize the super_gradients environment.</span>
+
+<span class="sd">    This function should be the first thing to be called by any code running super_gradients.</span>
+<span class="sd">    It resolves conflicts between the different tools, packages and environments used and prepares the super_gradients environment.</span>
 <span class="sd">    &quot;&quot;&quot;</span>
 <span class="sd">    &quot;&quot;&quot;</span>
+    <span class="k">if</span> <span class="ow">not</span> <span class="n">environment_config</span><span class="o">.</span><span class="n">INIT_TRAINER</span><span class="p">:</span>
+        <span class="n">register_hydra_resolvers</span><span class="p">()</span>
+
+        <span class="c1"># We pop local_rank if it was specified in the args, because it would break</span>
+        <span class="n">args_local_rank</span> <span class="o">=</span> <span class="n">pop_arg</span><span class="p">(</span><span class="s2">&quot;local_rank&quot;</span><span class="p">,</span> <span class="n">default_value</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
+
+        <span class="c1"># Set local_rank with priority order (env variable &gt; args.local_rank &gt; args.default_value)</span>
+        <span class="n">environment_config</span><span class="o">.</span><span class="n">DDP_LOCAL_RANK</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">getenv</span><span class="p">(</span><span class="s2">&quot;LOCAL_RANK&quot;</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="n">args_local_rank</span><span class="p">))</span>
+        <span class="n">environment_config</span><span class="o">.</span><span class="n">INIT_TRAINER</span> <span class="o">=</span> <span class="kc">True</span></div>
+
+
+<span class="k">def</span> <span class="nf">register_hydra_resolvers</span><span class="p">():</span>
+    <span class="sd">&quot;&quot;&quot;Register all the hydra resolvers required for the super-gradients recipes.&quot;&quot;&quot;</span>
+    <span class="n">OmegaConf</span><span class="o">.</span><span class="n">register_new_resolver</span><span class="p">(</span><span class="s2">&quot;hydra_output_dir&quot;</span><span class="p">,</span> <span class="n">hydra_output_dir_resolver</span><span class="p">,</span> <span class="n">replace</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
+    <span class="n">OmegaConf</span><span class="o">.</span><span class="n">register_new_resolver</span><span class="p">(</span><span class="s2">&quot;class&quot;</span><span class="p">,</span> <span class="k">lambda</span> <span class="o">*</span><span class="n">args</span><span class="p">:</span> <span class="n">get_cls</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">),</span> <span class="n">replace</span><span class="o">=</span><span class="k
+    <span class="n">OmegaConf</span><span class="o">.</span><span class="n">register_new_resolver</span><span class="p">(</span><span class="s2">&quot;add&quot;</span><span class="p">,</span> <span class="k">lambda</span> <span class="o">*</span><span class="n">args</span><span class="p">:</span> <span class="nb">sum</span><span class="p">(</span><span class="n">args</span><span class="p">),</span> <span class="n">replace</span><span class="o">=</span><span class="kc">True</span><span class="p"
+    <span class="n">OmegaConf</span><span class="o">.</span><span class="n">register_new_resolver</span><span class="p">(</span><span class="s2">&quot;cond&quot;</span><span class="p">,</span> <span class="k">lambda</span> <span class="n">boolean</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">x</span> <span class="k">if</span> <span class="n">boolean</span> <span class="k">else</span> <span class="
+    <span class="n">OmegaConf</span><span class="o">.</span><span class="n">register_new_resolver</span><span class="p">(</span><span class="s2">&quot;getitem&quot;</span><span class="p">,</span> <span class="k">lambda</span> <span class="n">container</span><span class="p">,</span> <span class="n">key</span><span class="p">:</span> <span class="n">container</span><span class="p">[</span><span class="n">key</span><span class="p">],</span> <span class="n">replace</span><span class="o">=</span><sp
+    <span class="n">OmegaConf</span><span class="o">.</span><span class="n">register_new_resolver</span><span class="p">(</span><span class="s2">&quot;first&quot;</span><span class="p">,</span> <span class="k">lambda</span> <span class="n">lst</span><span class="p">:</span> <span class="n">lst</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">replace</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>  <span class="c1
+    <span class="n">OmegaConf</span><span class="o">.</span><span class="n">register_new_resolver</span><span class="p">(</span><span class="s2">&quot;last&quot;</span><span class="p">,</span> <span class="k">lambda</span> <span class="n">lst</span><span class="p">:</span> <span class="n">lst</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">replace</span><span class="o">=</span><span class="kc">True</span><span class="p">)<
+
+
+<span class="k">def</span> <span class="nf">pop_arg</span><span class="p">(</span><span class="n">arg_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">default_value</span><span class="p">:</span> <span class="n">Any</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Any</span><span class="p">:</span>
+    <span class="sd">&quot;&quot;&quot;Get the specified args and remove them from argv&quot;&quot;&quot;</span>
 
 
     <span class="n">parser</span> <span class="o">=</span> <span class="n">argparse</span><span class="o">.</span><span class="n">ArgumentParser</span><span class="p">()</span>
     <span class="n">parser</span> <span class="o">=</span> <span class="n">argparse</span><span class="o">.</span><span class="n">ArgumentParser</span><span class="p">()</span>
-    <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="s2">&quot;--local_rank&quot;</span><span class="p">,</span> <span class="nb">type</span><span class="o">=</span><span class="nb">int</span><span class="p">,</span> <span class="n">default</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>  <span class="c1"># used by DDP</span>
+    <span class="n">parser</span><span class="o">.</span><span class="n">add_argument</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;--</span><span class="si">{</span><span class="n">arg_name</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="n">default_value</span><span class="p">)</span>
     <span class="n">args</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">parser</span><span class="o">.</span><span class="n">parse_known_args</span><span class="p">()</span>
     <span class="n">args</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">parser</span><span class="o">.</span><span class="n">parse_known_args</span><span class="p">()</span>
 
 
-    <span class="c1"># remove any flags starting with --local_rank from the argv list</span>
-    <span class="n">to_remove</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">filter</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">x</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s1">&#39;--local_rank&#39;</span><span class="p">),</span> <span class="n">sys</span><span class="o">.</span><span class="n">argv</spa
-    <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">to_remove</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
-        <span class="k">for</span> <span class="n">val</span> <span class="ow">in</span> <span class="n">to_remove</span><span class="p">:</span>
-            <span class="n">sys</span><span class="o">.</span><span class="n">argv</span><span class="o">.</span><span class="n">remove</span><span class="p">(</span><span class="n">val</span><span class="p">)</span>
-
-    <span class="n">environment_config</span><span class="o">.</span><span class="n">DDP_LOCAL_RANK</span> <span class="o">=</span> <span class="n">args</span><span class="o">.</span><span class="n">local_rank</span></div>
+    <span class="c1"># Remove the ddp args to not have a conflict with the use of hydra</span>
+    <span class="k">for</span> <span class="n">val</span> <span class="ow">in</span> <span class="nb">filter</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">x</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;--</span><span class="si">{</span><span class="n">arg_name</span><span class="si">}</span><span class="s2">&quot;</span><span cl
+        <span class="n">environment_config</span><span class="o">.</span><span class="n">EXTRA_ARGS</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">val</span><span class="p">)</span>
+        <span class="n">sys</span><span class="o">.</span><span class="n">argv</span><span class="o">.</span><span class="n">remove</span><span class="p">(</span><span class="n">val</span><span class="p">)</span>
+    <span class="k">return</span> <span class="nb">vars</span><span class="p">(</span><span class="n">args</span><span class="p">)[</span><span class="n">arg_name</span><span class="p">]</span>
 
 
 
 
-<div class="viewcode-block" id="is_distributed"><a class="viewcode-back" href="../../../../super_gradients.common.environment.html#super_gradients.common.is_distributed">[docs]</a><span class="k">def</span> <span class="nf">is_distributed</span><span class="p">()</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
+<div class="viewcode-block" id="is_distributed"><a class="viewcode-back" href="../../../../super_gradients.common.html#super_gradients.common.is_distributed">[docs]</a><span class="k">def</span> <span class="nf">is_distributed</span><span class="p">()</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
     <span class="k">return</span> <span class="n">environment_config</span><span class="o">.</span><span class="n">DDP_LOCAL_RANK</span> <span class="o">&gt;=</span> <span class="mi">0</span></div>
     <span class="k">return</span> <span class="n">environment_config</span><span class="o">.</span><span class="n">DDP_LOCAL_RANK</span> <span class="o">&gt;=</span> <span class="mi">0</span></div>
 
 
 
 
-<div class="viewcode-block" id="multi_process_safe"><a class="viewcode-back" href="../../../../super_gradients.common.environment.html#super_gradients.common.multi_process_safe">[docs]</a><span class="k">def</span> <span class="nf">multi_process_safe</span><span class="p">(</span><span class="n">func</span><span class="p">):</span>
+<span class="k">def</span> <span class="nf">is_rank_0</span><span class="p">()</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
+    <span class="sd">&quot;&quot;&quot;Check if the node was launched with torch.distributed.launch and if the node is of rank 0&quot;&quot;&quot;</span>
+    <span class="k">return</span> <span class="n">os</span><span class="o">.</span><span class="n">getenv</span><span class="p">(</span><span class="s2">&quot;LOCAL_RANK&quot;</span><span class="p">)</span> <span class="o">==</span> <span class="s2">&quot;0&quot;</span>
+
+
+<span class="k">def</span> <span class="nf">is_launched_using_sg</span><span class="p">():</span>
+    <span class="sd">&quot;&quot;&quot;Check if the current process is a subprocess launched using SG restart_script_with_ddp&quot;&quot;&quot;</span>
+    <span class="k">return</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;TORCHELASTIC_RUN_ID&quot;</span><span class="p">)</span> <span class="o">==</span> <span class="s2">&quot;sg_initiated&quot;</span>
+
+
+<span class="k">def</span> <span class="nf">is_main_process</span><span class="p">():</span>
+    <span class="sd">&quot;&quot;&quot;Check if current process is considered as the main process (i.e. is responsible for sanity check, atexit upload, ...).</span>
+<span class="sd">    The definition ensures that 1 and only 1 process follows this condition, regardless of how the run was started.</span>
+
+<span class="sd">    The rule is as follow:</span>
+<span class="sd">        - If not DDP: main process is current process</span>
+<span class="sd">        - If DDP launched using SuperGradients: main process is the launching process (rank=-1)</span>
+<span class="sd">        - If DDP launched with torch: main process is rank 0</span>
+<span class="sd">    &quot;&quot;&quot;</span>
+    <span class="k">if</span> <span class="ow">not</span> <span class="n">is_distributed</span><span class="p">():</span>  <span class="c1"># If no DDP, or DDP launching process</span>
+        <span class="k">return</span> <span class="kc">True</span>
+    <span class="k">elif</span> <span class="n">is_rank_0</span><span class="p">()</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">is_launched_using_sg</span><span class="p">():</span>  <span class="c1"># If DDP launched using torch.distributed.launch or torchrun, we need to run the check on rank 0</span>
+        <span class="k">return</span> <span class="kc">True</span>
+    <span class="k">else</span><span class="p">:</span>
+        <span class="k">return</span> <span class="kc">False</span>
+
+
+<span class="k">def</span> <span class="nf">multi_process_safe</span><span class="p">(</span><span class="n">func</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    A decorator for making sure a function runs only in main process.</span>
 <span class="sd">    A decorator for making sure a function runs only in main process.</span>
 <span class="sd">    If not in DDP mode (local_rank = -1), the function will run.</span>
 <span class="sd">    If not in DDP mode (local_rank = -1), the function will run.</span>
 <span class="sd">    If in DDP mode, the function will run only in the main process (local_rank = 0)</span>
 <span class="sd">    If in DDP mode, the function will run only in the main process (local_rank = 0)</span>
 <span class="sd">    This works only for functions with no return value</span>
 <span class="sd">    This works only for functions with no return value</span>
 <span class="sd">    &quot;&quot;&quot;</span>
 <span class="sd">    &quot;&quot;&quot;</span>
+
     <span class="k">def</span> <span class="nf">do_nothing</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
     <span class="k">def</span> <span class="nf">do_nothing</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
         <span class="k">pass</span>
         <span class="k">pass</span>
 
 
@@ -176,7 +256,18 @@
         <span class="k">else</span><span class="p">:</span>
         <span class="k">else</span><span class="p">:</span>
             <span class="k">return</span> <span class="n">do_nothing</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
             <span class="k">return</span> <span class="n">do_nothing</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
 
 
-    <span class="k">return</span> <span class="n">wrapper</span></div>
+    <span class="k">return</span> <span class="n">wrapper</span>
+
+
+<span class="k">def</span> <span class="nf">find_free_port</span><span class="p">()</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
+    <span class="sd">&quot;&quot;&quot;Find an available port of current machine/node.</span>
+<span class="sd">    Note: there is still a chance the port could be taken by other processes.&quot;&quot;&quot;</span>
+
+    <span class="k">with</span> <span class="n">socket</span><span class="o">.</span><span class="n">socket</span><span class="p">(</span><span class="n">socket</span><span class="o">.</span><span class="n">AF_INET</span><span class="p">,</span> <span class="n">socket</span><span class="o">.</span><span class="n">SOCK_STREAM</span><span class="p">)</span> <span class="k">as</span> <span class="n">sock</span><span class="p">:</span>
+        <span class="c1"># Binding to port 0 will cause the OS to find an available port for us</span>
+        <span class="n">sock</span><span class="o">.</span><span class="n">bind</span><span class="p">((</span><span class="s2">&quot;&quot;</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
+        <span class="n">_ip</span><span class="p">,</span> <span class="n">port</span> <span class="o">=</span> <span class="n">sock</span><span class="o">.</span><span class="n">getsockname</span><span class="p">()</span>
+    <span class="k">return</span> <span class="n">port</span>
 </pre></div>
 </pre></div>
 
 
            </div>
            </div>
@@ -206,4 +297,4 @@
   </script> 
   </script> 
 
 
 </body>
 </body>
-</html>
+</html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.common.object_names &mdash; SuperGradients 3.0.3 documentation</title>
  7. <link rel="stylesheet" href="../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../_static/graphviz.css" type="text/css" />
  10. <link rel="stylesheet" href="../../../_static/custom.css" type="text/css" />
  11. <!--[if lt IE 9]>
  12. <script src="../../../_static/js/html5shiv.min.js"></script>
  13. <![endif]-->
  14. <script data-url_root="../../../" id="documentation_options" src="../../../_static/documentation_options.js"></script>
  15. <script src="../../../_static/jquery.js"></script>
  16. <script src="../../../_static/underscore.js"></script>
  17. <script src="../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
  18. <script src="../../../_static/doctools.js"></script>
  19. <script src="../../../_static/sphinx_highlight.js"></script>
  20. <script src="../../../_static/js/theme.js"></script>
  21. <link rel="index" title="Index" href="../../../genindex.html" />
  22. <link rel="search" title="Search" href="../../../search.html" />
  23. </head>
  24. <body class="wy-body-for-nav">
  25. <div class="wy-grid-for-nav">
  26. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  27. <div class="wy-side-scroll">
  28. <div class="wy-side-nav-search" >
  29. <a href="../../../index.html" class="icon icon-home"> SuperGradients
  30. </a>
  31. <div role="search">
  32. <form id="rtd-search-form" class="wy-form" action="../../../search.html" method="get">
  33. <input type="text" name="q" placeholder="Search docs" />
  34. <input type="hidden" name="check_keywords" value="yes" />
  35. <input type="hidden" name="area" value="default" />
  36. </form>
  37. </div>
  38. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  39. <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
  40. <ul>
  41. <li class="toctree-l1"><a class="reference internal" href="../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
  42. <li class="toctree-l1"><a class="reference internal" href="../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../welcome.html#quick-installation">Quick Installation</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../welcome.html#what-s-new">What’s New</a></li>
  45. <li class="toctree-l1"><a class="reference internal" href="../../../welcome.html#coming-soon">Coming soon</a></li>
  46. <li class="toctree-l1"><a class="reference internal" href="../../../welcome.html#table-of-content">Table of Content</a></li>
  47. <li class="toctree-l1"><a class="reference internal" href="../../../welcome.html#getting-started">Getting Started</a></li>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../welcome.html#advanced-features">Advanced Features</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../welcome.html#installation-methods">Installation Methods</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../welcome.html#documentation">Documentation</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../welcome.html#contributing">Contributing</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../welcome.html#citation">Citation</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../welcome.html#community">Community</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../welcome.html#license">License</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../welcome.html#deci-platform">Deci Platform</a></li>
  57. </ul>
  58. <p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
  59. <ul>
  60. <li class="toctree-l1"><a class="reference internal" href="../../../super_gradients.common.html">Common package</a></li>
  61. <li class="toctree-l1"><a class="reference internal" href="../../../super_gradients.training.html">Training package</a></li>
  62. </ul>
  63. </div>
  64. </div>
  65. </nav>
  66. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  67. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  68. <a href="../../../index.html">SuperGradients</a>
  69. </nav>
  70. <div class="wy-nav-content">
  71. <div class="rst-content">
  72. <div role="navigation" aria-label="Page navigation">
  73. <ul class="wy-breadcrumbs">
  74. <li><a href="../../../index.html" class="icon icon-home"></a> &raquo;</li>
  75. <li><a href="../../index.html">Module code</a> &raquo;</li>
  76. <li>super_gradients.common.object_names</li>
  77. <li class="wy-breadcrumbs-aside">
  78. </li>
  79. </ul>
  80. <hr/>
  81. </div>
  82. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  83. <div itemprop="articleBody">
  84. <h1>Source code for super_gradients.common.object_names</h1><div class="highlight"><pre>
  85. <div class="viewcode-block" id="Losses"><a class="viewcode-back" href="../../../super_gradients.training.html#super_gradients.training.losses.Losses">[docs]</a><span></span><span class="k">class</span> <span class="nc">Losses</span><span class="p">:</span>
  86. <span class="sd">&quot;&quot;&quot;Static class holding all the supported loss names&quot;&quot;&quot;</span>
  87. <span class="n">CROSS_ENTROPY</span> <span class="o">=</span> <span class="s2">&quot;cross_entropy&quot;</span>
  88. <span class="n">MSE</span> <span class="o">=</span> <span class="s2">&quot;mse&quot;</span>
  89. <span class="n">R_SQUARED_LOSS</span> <span class="o">=</span> <span class="s2">&quot;r_squared_loss&quot;</span>
  90. <span class="n">SHELFNET_OHEM_LOSS</span> <span class="o">=</span> <span class="s2">&quot;shelfnet_ohem_loss&quot;</span>
  91. <span class="n">SHELFNET_SE_LOSS</span> <span class="o">=</span> <span class="s2">&quot;shelfnet_se_loss&quot;</span>
  92. <span class="n">YOLOX_LOSS</span> <span class="o">=</span> <span class="s2">&quot;yolox_loss&quot;</span>
  93. <span class="n">YOLOX_FAST_LOSS</span> <span class="o">=</span> <span class="s2">&quot;yolox_fast_loss&quot;</span>
  94. <span class="n">SSD_LOSS</span> <span class="o">=</span> <span class="s2">&quot;ssd_loss&quot;</span>
  95. <span class="n">STDC_LOSS</span> <span class="o">=</span> <span class="s2">&quot;stdc_loss&quot;</span>
  96. <span class="n">BCE_DICE_LOSS</span> <span class="o">=</span> <span class="s2">&quot;bce_dice_loss&quot;</span>
  97. <span class="n">KD_LOSS</span> <span class="o">=</span> <span class="s2">&quot;kd_loss&quot;</span>
  98. <span class="n">DICE_CE_EDGE_LOSS</span> <span class="o">=</span> <span class="s2">&quot;dice_ce_edge_loss&quot;</span></div>
  99. <div class="viewcode-block" id="Metrics"><a class="viewcode-back" href="../../../super_gradients.training.html#super_gradients.training.losses.Metrics">[docs]</a><span class="k">class</span> <span class="nc">Metrics</span><span class="p">:</span>
  100. <span class="sd">&quot;&quot;&quot;Static class holding all the supported metric names&quot;&quot;&quot;</span>
  101. <span class="n">ACCURACY</span> <span class="o">=</span> <span class="s2">&quot;Accuracy&quot;</span>
  102. <span class="n">TOP5</span> <span class="o">=</span> <span class="s2">&quot;Top5&quot;</span>
  103. <span class="n">DETECTION_METRICS</span> <span class="o">=</span> <span class="s2">&quot;DetectionMetrics&quot;</span>
  104. <span class="n">DETECTION_METRICS_050_095</span> <span class="o">=</span> <span class="s2">&quot;DetectionMetrics_050_095&quot;</span>
  105. <span class="n">DETECTION_METRICS_050</span> <span class="o">=</span> <span class="s2">&quot;DetectionMetrics_050&quot;</span>
  106. <span class="n">DETECTION_METRICS_075</span> <span class="o">=</span> <span class="s2">&quot;DetectionMetrics_075&quot;</span>
  107. <span class="n">IOU</span> <span class="o">=</span> <span class="s2">&quot;IoU&quot;</span>
  108. <span class="n">BINARY_IOU</span> <span class="o">=</span> <span class="s2">&quot;BinaryIOU&quot;</span>
  109. <span class="n">DICE</span> <span class="o">=</span> <span class="s2">&quot;Dice&quot;</span>
  110. <span class="n">BINARY_DICE</span> <span class="o">=</span> <span class="s2">&quot;BinaryDice&quot;</span>
  111. <span class="n">PIXEL_ACCURACY</span> <span class="o">=</span> <span class="s2">&quot;PixelAccuracy&quot;</span></div>
  112. <div class="viewcode-block" id="Transforms"><a class="viewcode-back" href="../../../super_gradients.training.html#super_gradients.training.losses.Transforms">[docs]</a><span class="k">class</span> <span class="nc">Transforms</span><span class="p">:</span>
  113. <span class="sd">&quot;&quot;&quot;Static class holding all the supported transform names&quot;&quot;&quot;</span>
  114. <span class="c1"># From SG</span>
  115. <span class="n">SegRandomFlip</span> <span class="o">=</span> <span class="s2">&quot;SegRandomFlip&quot;</span>
  116. <span class="n">SegResize</span> <span class="o">=</span> <span class="s2">&quot;SegResize&quot;</span>
  117. <span class="n">SegRescale</span> <span class="o">=</span> <span class="s2">&quot;SegRescale&quot;</span>
  118. <span class="n">SegRandomRescale</span> <span class="o">=</span> <span class="s2">&quot;SegRandomRescale&quot;</span>
  119. <span class="n">SegRandomRotate</span> <span class="o">=</span> <span class="s2">&quot;SegRandomRotate&quot;</span>
  120. <span class="n">SegCropImageAndMask</span> <span class="o">=</span> <span class="s2">&quot;SegCropImageAndMask&quot;</span>
  121. <span class="n">SegRandomGaussianBlur</span> <span class="o">=</span> <span class="s2">&quot;SegRandomGaussianBlur&quot;</span>
  122. <span class="n">SegPadShortToCropSize</span> <span class="o">=</span> <span class="s2">&quot;SegPadShortToCropSize&quot;</span>
  123. <span class="n">SegColorJitter</span> <span class="o">=</span> <span class="s2">&quot;SegColorJitter&quot;</span>
  124. <span class="n">DetectionMosaic</span> <span class="o">=</span> <span class="s2">&quot;DetectionMosaic&quot;</span>
  125. <span class="n">DetectionRandomAffine</span> <span class="o">=</span> <span class="s2">&quot;DetectionRandomAffine&quot;</span>
  126. <span class="n">DetectionMixup</span> <span class="o">=</span> <span class="s2">&quot;DetectionMixup&quot;</span>
  127. <span class="n">DetectionHSV</span> <span class="o">=</span> <span class="s2">&quot;DetectionHSV&quot;</span>
  128. <span class="n">DetectionHorizontalFlip</span> <span class="o">=</span> <span class="s2">&quot;DetectionHorizontalFlip&quot;</span>
  129. <span class="n">DetectionPaddedRescale</span> <span class="o">=</span> <span class="s2">&quot;DetectionPaddedRescale&quot;</span>
  130. <span class="n">DetectionTargetsFormat</span> <span class="o">=</span> <span class="s2">&quot;DetectionTargetsFormat&quot;</span>
  131. <span class="n">DetectionTargetsFormatTransform</span> <span class="o">=</span> <span class="s2">&quot;DetectionTargetsFormatTransform&quot;</span>
  132. <span class="n">RandomResizedCropAndInterpolation</span> <span class="o">=</span> <span class="s2">&quot;RandomResizedCropAndInterpolation&quot;</span>
  133. <span class="n">RandAugmentTransform</span> <span class="o">=</span> <span class="s2">&quot;RandAugmentTransform&quot;</span>
  134. <span class="n">Lighting</span> <span class="o">=</span> <span class="s2">&quot;Lighting&quot;</span>
  135. <span class="n">RandomErase</span> <span class="o">=</span> <span class="s2">&quot;RandomErase&quot;</span>
  136. <span class="c1"># From torch</span>
  137. <span class="n">Compose</span> <span class="o">=</span> <span class="s2">&quot;Compose&quot;</span>
  138. <span class="n">ToTensor</span> <span class="o">=</span> <span class="s2">&quot;ToTensor&quot;</span>
  139. <span class="n">PILToTensor</span> <span class="o">=</span> <span class="s2">&quot;PILToTensor&quot;</span>
  140. <span class="n">ConvertImageDtype</span> <span class="o">=</span> <span class="s2">&quot;ConvertImageDtype&quot;</span>
  141. <span class="n">ToPILImage</span> <span class="o">=</span> <span class="s2">&quot;ToPILImage&quot;</span>
  142. <span class="n">Normalize</span> <span class="o">=</span> <span class="s2">&quot;Normalize&quot;</span>
  143. <span class="n">Resize</span> <span class="o">=</span> <span class="s2">&quot;Resize&quot;</span>
  144. <span class="n">CenterCrop</span> <span class="o">=</span> <span class="s2">&quot;CenterCrop&quot;</span>
  145. <span class="n">Pad</span> <span class="o">=</span> <span class="s2">&quot;Pad&quot;</span>
  146. <span class="n">Lambda</span> <span class="o">=</span> <span class="s2">&quot;Lambda&quot;</span>
  147. <span class="n">RandomApply</span> <span class="o">=</span> <span class="s2">&quot;RandomApply&quot;</span>
  148. <span class="n">RandomChoice</span> <span class="o">=</span> <span class="s2">&quot;RandomChoice&quot;</span>
  149. <span class="n">RandomOrder</span> <span class="o">=</span> <span class="s2">&quot;RandomOrder&quot;</span>
  150. <span class="n">RandomCrop</span> <span class="o">=</span> <span class="s2">&quot;RandomCrop&quot;</span>
  151. <span class="n">RandomHorizontalFlip</span> <span class="o">=</span> <span class="s2">&quot;RandomHorizontalFlip&quot;</span>
  152. <span class="n">RandomVerticalFlip</span> <span class="o">=</span> <span class="s2">&quot;RandomVerticalFlip&quot;</span>
  153. <span class="n">RandomResizedCrop</span> <span class="o">=</span> <span class="s2">&quot;RandomResizedCrop&quot;</span>
  154. <span class="n">FiveCrop</span> <span class="o">=</span> <span class="s2">&quot;FiveCrop&quot;</span>
  155. <span class="n">TenCrop</span> <span class="o">=</span> <span class="s2">&quot;TenCrop&quot;</span>
  156. <span class="n">LinearTransformation</span> <span class="o">=</span> <span class="s2">&quot;LinearTransformation&quot;</span>
  157. <span class="n">ColorJitter</span> <span class="o">=</span> <span class="s2">&quot;ColorJitter&quot;</span>
  158. <span class="n">RandomRotation</span> <span class="o">=</span> <span class="s2">&quot;RandomRotation&quot;</span>
  159. <span class="n">RandomAffine</span> <span class="o">=</span> <span class="s2">&quot;RandomAffine&quot;</span>
  160. <span class="n">Grayscale</span> <span class="o">=</span> <span class="s2">&quot;Grayscale&quot;</span>
  161. <span class="n">RandomGrayscale</span> <span class="o">=</span> <span class="s2">&quot;RandomGrayscale&quot;</span>
  162. <span class="n">RandomPerspective</span> <span class="o">=</span> <span class="s2">&quot;RandomPerspective&quot;</span>
  163. <span class="n">RandomErasing</span> <span class="o">=</span> <span class="s2">&quot;RandomErasing&quot;</span>
  164. <span class="n">GaussianBlur</span> <span class="o">=</span> <span class="s2">&quot;GaussianBlur&quot;</span>
  165. <span class="n">InterpolationMode</span> <span class="o">=</span> <span class="s2">&quot;InterpolationMode&quot;</span>
  166. <span class="n">RandomInvert</span> <span class="o">=</span> <span class="s2">&quot;RandomInvert&quot;</span>
  167. <span class="n">RandomPosterize</span> <span class="o">=</span> <span class="s2">&quot;RandomPosterize&quot;</span>
  168. <span class="n">RandomSolarize</span> <span class="o">=</span> <span class="s2">&quot;RandomSolarize&quot;</span>
  169. <span class="n">RandomAdjustSharpness</span> <span class="o">=</span> <span class="s2">&quot;RandomAdjustSharpness&quot;</span>
  170. <span class="n">RandomAutocontrast</span> <span class="o">=</span> <span class="s2">&quot;RandomAutocontrast&quot;</span>
  171. <span class="n">RandomEqualize</span> <span class="o">=</span> <span class="s2">&quot;RandomEqualize&quot;</span></div>
  172. <span class="k">class</span> <span class="nc">Optimizers</span><span class="p">:</span>
  173. <span class="sd">&quot;&quot;&quot;Static class holding all the supported optimizer names&quot;&quot;&quot;</span>
  174. <span class="n">SGD</span> <span class="o">=</span> <span class="s2">&quot;SGD&quot;</span>
  175. <span class="n">ADAM</span> <span class="o">=</span> <span class="s2">&quot;Adam&quot;</span>
  176. <span class="n">RMS_PROP</span> <span class="o">=</span> <span class="s2">&quot;RMSprop&quot;</span>
  177. <span class="n">RMS_PROP_TF</span> <span class="o">=</span> <span class="s2">&quot;RMSpropTF&quot;</span>
  178. <span class="n">LAMB</span> <span class="o">=</span> <span class="s2">&quot;Lamb&quot;</span>
  179. <span class="k">class</span> <span class="nc">Callbacks</span><span class="p">:</span>
  180. <span class="sd">&quot;&quot;&quot;Static class holding all the supported callback names&quot;&quot;&quot;</span>
  181. <span class="n">DECI_LAB_UPLOAD</span> <span class="o">=</span> <span class="s2">&quot;DeciLabUploadCallback&quot;</span>
  182. <span class="n">LR_CALLBACK_BASE</span> <span class="o">=</span> <span class="s2">&quot;LRCallbackBase&quot;</span>
  183. <span class="n">LR_SCHEDULER</span> <span class="o">=</span> <span class="s2">&quot;LRSchedulerCallback&quot;</span>
  184. <span class="n">METRICS_UPDATE</span> <span class="o">=</span> <span class="s2">&quot;MetricsUpdateCallback&quot;</span>
  185. <span class="n">MODEL_CONVERSION_CHECK</span> <span class="o">=</span> <span class="s2">&quot;ModelConversionCheckCallback&quot;</span>
  186. <span class="n">EARLY_STOP</span> <span class="o">=</span> <span class="s2">&quot;EarlyStop&quot;</span>
  187. <span class="n">DETECTION_MULTISCALE_PREPREDICTION</span> <span class="o">=</span> <span class="s2">&quot;DetectionMultiscalePrePredictionCallback&quot;</span>
  188. <span class="n">YOLOX_TRAINING_STAGE_SWITCH</span> <span class="o">=</span> <span class="s2">&quot;YoloXTrainingStageSwitchCallback&quot;</span>
  189. <span class="k">class</span> <span class="nc">LRSchedulers</span><span class="p">:</span>
  190. <span class="sd">&quot;&quot;&quot;Static class to hold all the supported LR Scheduler names&quot;&quot;&quot;</span>
  191. <span class="n">STEP</span> <span class="o">=</span> <span class="s2">&quot;step&quot;</span>
  192. <span class="n">POLY</span> <span class="o">=</span> <span class="s2">&quot;poly&quot;</span>
  193. <span class="n">COSINE</span> <span class="o">=</span> <span class="s2">&quot;cosine&quot;</span>
  194. <span class="n">EXP</span> <span class="o">=</span> <span class="s2">&quot;exp&quot;</span>
  195. <span class="n">FUNCTION</span> <span class="o">=</span> <span class="s2">&quot;function&quot;</span>
  196. <span class="k">class</span> <span class="nc">LRWarmups</span><span class="p">:</span>
  197. <span class="sd">&quot;&quot;&quot;Static class to hold all the supported LR Warmup names&quot;&quot;&quot;</span>
  198. <span class="n">LINEAR_STEP</span> <span class="o">=</span> <span class="s2">&quot;linear_step&quot;</span>
  199. <span class="k">class</span> <span class="nc">Samplers</span><span class="p">:</span>
  200. <span class="sd">&quot;&quot;&quot;Static class to hold all the supported Samplers names&quot;&quot;&quot;</span>
  201. <span class="n">INFINITE</span> <span class="o">=</span> <span class="s2">&quot;InfiniteSampler&quot;</span>
  202. <span class="n">REPEAT_AUG</span> <span class="o">=</span> <span class="s2">&quot;RepeatAugSampler&quot;</span>
  203. <span class="n">DISTRIBUTED</span> <span class="o">=</span> <span class="s2">&quot;DistributedSampler&quot;</span>
  204. <span class="k">class</span> <span class="nc">ContextModules</span><span class="p">:</span>
  205. <span class="sd">&quot;&quot;&quot;Static class to hold all the segmentation context module names&quot;&quot;&quot;</span>
  206. <span class="n">ASPP</span> <span class="o">=</span> <span class="s2">&quot;ASPP&quot;</span>
  207. <span class="n">SPPM</span> <span class="o">=</span> <span class="s2">&quot;SPPM&quot;</span>
  208. <span class="k">class</span> <span class="nc">Models</span><span class="p">:</span>
  209. <span class="sd">&quot;&quot;&quot;Static class to hold all the available model names&quot;&quot;&quot;</span>
  210. <span class="n">RESNET18</span> <span class="o">=</span> <span class="s2">&quot;resnet18&quot;</span>
  211. <span class="n">RESNET34</span> <span class="o">=</span> <span class="s2">&quot;resnet34&quot;</span>
  212. <span class="n">RESNET50_3343</span> <span class="o">=</span> <span class="s2">&quot;resnet50_3343&quot;</span>
  213. <span class="n">RESNET50</span> <span class="o">=</span> <span class="s2">&quot;resnet50&quot;</span>
  214. <span class="n">RESNET101</span> <span class="o">=</span> <span class="s2">&quot;resnet101&quot;</span>
  215. <span class="n">RESNET152</span> <span class="o">=</span> <span class="s2">&quot;resnet152&quot;</span>
  216. <span class="n">RESNET18_CIFAR</span> <span class="o">=</span> <span class="s2">&quot;resnet18_cifar&quot;</span>
  217. <span class="n">CUSTOM_RESNET</span> <span class="o">=</span> <span class="s2">&quot;custom_resnet&quot;</span>
  218. <span class="n">CUSTOM_RESNET50</span> <span class="o">=</span> <span class="s2">&quot;custom_resnet50&quot;</span>
  219. <span class="n">CUSTOM_RESNET_CIFAR</span> <span class="o">=</span> <span class="s2">&quot;custom_resnet_cifar&quot;</span>
  220. <span class="n">CUSTOM_RESNET50_CIFAR</span> <span class="o">=</span> <span class="s2">&quot;custom_resnet50_cifar&quot;</span>
  221. <span class="n">MOBILENET_V2</span> <span class="o">=</span> <span class="s2">&quot;mobilenet_v2&quot;</span>
  222. <span class="n">MOBILE_NET_V2_135</span> <span class="o">=</span> <span class="s2">&quot;mobile_net_v2_135&quot;</span>
  223. <span class="n">CUSTOM_MOBILENET_V2</span> <span class="o">=</span> <span class="s2">&quot;custom_mobilenet_v2&quot;</span>
  224. <span class="n">MOBILENET_V3_LARGE</span> <span class="o">=</span> <span class="s2">&quot;mobilenet_v3_large&quot;</span>
  225. <span class="n">MOBILENET_V3_SMALL</span> <span class="o">=</span> <span class="s2">&quot;mobilenet_v3_small&quot;</span>
  226. <span class="n">MOBILENET_V3_CUSTOM</span> <span class="o">=</span> <span class="s2">&quot;mobilenet_v3_custom&quot;</span>
  227. <span class="n">CUSTOM_DENSENET</span> <span class="o">=</span> <span class="s2">&quot;custom_densenet&quot;</span>
  228. <span class="n">DENSENET121</span> <span class="o">=</span> <span class="s2">&quot;densenet121&quot;</span>
  229. <span class="n">DENSENET161</span> <span class="o">=</span> <span class="s2">&quot;densenet161&quot;</span>
  230. <span class="n">DENSENET169</span> <span class="o">=</span> <span class="s2">&quot;densenet169&quot;</span>
  231. <span class="n">DENSENET201</span> <span class="o">=</span> <span class="s2">&quot;densenet201&quot;</span>
  232. <span class="n">SHELFNET18_LW</span> <span class="o">=</span> <span class="s2">&quot;shelfnet18_lw&quot;</span>
  233. <span class="n">SHELFNET34_LW</span> <span class="o">=</span> <span class="s2">&quot;shelfnet34_lw&quot;</span>
  234. <span class="n">SHELFNET50_3343</span> <span class="o">=</span> <span class="s2">&quot;shelfnet50_3343&quot;</span>
  235. <span class="n">SHELFNET50</span> <span class="o">=</span> <span class="s2">&quot;shelfnet50&quot;</span>
  236. <span class="n">SHELFNET101</span> <span class="o">=</span> <span class="s2">&quot;shelfnet101&quot;</span>
  237. <span class="n">SHUFFLENET_V2_X0_5</span> <span class="o">=</span> <span class="s2">&quot;shufflenet_v2_x0_5&quot;</span>
  238. <span class="n">SHUFFLENET_V2_X1_0</span> <span class="o">=</span> <span class="s2">&quot;shufflenet_v2_x1_0&quot;</span>
  239. <span class="n">SHUFFLENET_V2_X1_5</span> <span class="o">=</span> <span class="s2">&quot;shufflenet_v2_x1_5&quot;</span>
  240. <span class="n">SHUFFLENET_V2_X2_0</span> <span class="o">=</span> <span class="s2">&quot;shufflenet_v2_x2_0&quot;</span>
  241. <span class="n">SHUFFLENET_V2_CUSTOM5</span> <span class="o">=</span> <span class="s2">&quot;shufflenet_v2_custom5&quot;</span>
  242. <span class="n">DARKNET53</span> <span class="o">=</span> <span class="s2">&quot;darknet53&quot;</span>
  243. <span class="n">CSP_DARKNET53</span> <span class="o">=</span> <span class="s2">&quot;csp_darknet53&quot;</span>
  244. <span class="n">RESNEXT50</span> <span class="o">=</span> <span class="s2">&quot;resnext50&quot;</span>
  245. <span class="n">RESNEXT101</span> <span class="o">=</span> <span class="s2">&quot;resnext101&quot;</span>
  246. <span class="n">GOOGLENET_V1</span> <span class="o">=</span> <span class="s2">&quot;googlenet_v1&quot;</span>
  247. <span class="n">EFFICIENTNET_B0</span> <span class="o">=</span> <span class="s2">&quot;efficientnet_b0&quot;</span>
  248. <span class="n">EFFICIENTNET_B1</span> <span class="o">=</span> <span class="s2">&quot;efficientnet_b1&quot;</span>
  249. <span class="n">EFFICIENTNET_B2</span> <span class="o">=</span> <span class="s2">&quot;efficientnet_b2&quot;</span>
  250. <span class="n">EFFICIENTNET_B3</span> <span class="o">=</span> <span class="s2">&quot;efficientnet_b3&quot;</span>
  251. <span class="n">EFFICIENTNET_B4</span> <span class="o">=</span> <span class="s2">&quot;efficientnet_b4&quot;</span>
  252. <span class="n">EFFICIENTNET_B5</span> <span class="o">=</span> <span class="s2">&quot;efficientnet_b5&quot;</span>
  253. <span class="n">EFFICIENTNET_B6</span> <span class="o">=</span> <span class="s2">&quot;efficientnet_b6&quot;</span>
  254. <span class="n">EFFICIENTNET_B7</span> <span class="o">=</span> <span class="s2">&quot;efficientnet_b7&quot;</span>
  255. <span class="n">EFFICIENTNET_B8</span> <span class="o">=</span> <span class="s2">&quot;efficientnet_b8&quot;</span>
  256. <span class="n">EFFICIENTNET_L2</span> <span class="o">=</span> <span class="s2">&quot;efficientnet_l2&quot;</span>
  257. <span class="n">CUSTOMIZEDEFFICIENTNET</span> <span class="o">=</span> <span class="s2">&quot;CustomizedEfficientnet&quot;</span>
  258. <span class="n">REGNETY200</span> <span class="o">=</span> <span class="s2">&quot;regnetY200&quot;</span>
  259. <span class="n">REGNETY400</span> <span class="o">=</span> <span class="s2">&quot;regnetY400&quot;</span>
  260. <span class="n">REGNETY600</span> <span class="o">=</span> <span class="s2">&quot;regnetY600&quot;</span>
  261. <span class="n">REGNETY800</span> <span class="o">=</span> <span class="s2">&quot;regnetY800&quot;</span>
  262. <span class="n">CUSTOM_REGNET</span> <span class="o">=</span> <span class="s2">&quot;custom_regnet&quot;</span>
  263. <span class="n">CUSTOM_ANYNET</span> <span class="o">=</span> <span class="s2">&quot;custom_anynet&quot;</span>
  264. <span class="n">NAS_REGNET</span> <span class="o">=</span> <span class="s2">&quot;nas_regnet&quot;</span>
  265. <span class="n">YOLOX_N</span> <span class="o">=</span> <span class="s2">&quot;yolox_n&quot;</span>
  266. <span class="n">YOLOX_T</span> <span class="o">=</span> <span class="s2">&quot;yolox_t&quot;</span>
  267. <span class="n">YOLOX_S</span> <span class="o">=</span> <span class="s2">&quot;yolox_s&quot;</span>
  268. <span class="n">YOLOX_M</span> <span class="o">=</span> <span class="s2">&quot;yolox_m&quot;</span>
  269. <span class="n">YOLOX_L</span> <span class="o">=</span> <span class="s2">&quot;yolox_l&quot;</span>
  270. <span class="n">YOLOX_X</span> <span class="o">=</span> <span class="s2">&quot;yolox_x&quot;</span>
  271. <span class="n">CUSTOM_YOLO_X</span> <span class="o">=</span> <span class="s2">&quot;custom_yolox&quot;</span>
  272. <span class="n">SSD_MOBILENET_V1</span> <span class="o">=</span> <span class="s2">&quot;ssd_mobilenet_v1&quot;</span>
  273. <span class="n">SSD_LITE_MOBILENET_V2</span> <span class="o">=</span> <span class="s2">&quot;ssd_lite_mobilenet_v2&quot;</span>
  274. <span class="n">REPVGG_A0</span> <span class="o">=</span> <span class="s2">&quot;repvgg_a0&quot;</span>
  275. <span class="n">REPVGG_A1</span> <span class="o">=</span> <span class="s2">&quot;repvgg_a1&quot;</span>
  276. <span class="n">REPVGG_A2</span> <span class="o">=</span> <span class="s2">&quot;repvgg_a2&quot;</span>
  277. <span class="n">REPVGG_B0</span> <span class="o">=</span> <span class="s2">&quot;repvgg_b0&quot;</span>
  278. <span class="n">REPVGG_B1</span> <span class="o">=</span> <span class="s2">&quot;repvgg_b1&quot;</span>
  279. <span class="n">REPVGG_B2</span> <span class="o">=</span> <span class="s2">&quot;repvgg_b2&quot;</span>
  280. <span class="n">REPVGG_B3</span> <span class="o">=</span> <span class="s2">&quot;repvgg_b3&quot;</span>
  281. <span class="n">REPVGG_D2SE</span> <span class="o">=</span> <span class="s2">&quot;repvgg_d2se&quot;</span>
  282. <span class="n">REPVGG_CUSTOM</span> <span class="o">=</span> <span class="s2">&quot;repvgg_custom&quot;</span>
  283. <span class="n">DDRNET_23</span> <span class="o">=</span> <span class="s2">&quot;ddrnet_23&quot;</span>
  284. <span class="n">DDRNET_23_SLIM</span> <span class="o">=</span> <span class="s2">&quot;ddrnet_23_slim&quot;</span>
  285. <span class="n">CUSTOM_DDRNET_23</span> <span class="o">=</span> <span class="s2">&quot;custom_ddrnet_23&quot;</span>
  286. <span class="n">STDC1_CLASSIFICATION</span> <span class="o">=</span> <span class="s2">&quot;stdc1_classification&quot;</span>
  287. <span class="n">STDC2_CLASSIFICATION</span> <span class="o">=</span> <span class="s2">&quot;stdc2_classification&quot;</span>
  288. <span class="n">STDC1_SEG</span> <span class="o">=</span> <span class="s2">&quot;stdc1_seg&quot;</span>
  289. <span class="n">STDC1_SEG50</span> <span class="o">=</span> <span class="s2">&quot;stdc1_seg50&quot;</span>
  290. <span class="n">STDC1_SEG75</span> <span class="o">=</span> <span class="s2">&quot;stdc1_seg75&quot;</span>
  291. <span class="n">STDC2_SEG</span> <span class="o">=</span> <span class="s2">&quot;stdc2_seg&quot;</span>
  292. <span class="n">STDC2_SEG50</span> <span class="o">=</span> <span class="s2">&quot;stdc2_seg50&quot;</span>
  293. <span class="n">STDC2_SEG75</span> <span class="o">=</span> <span class="s2">&quot;stdc2_seg75&quot;</span>
  294. <span class="n">CUSTOM_STDC</span> <span class="o">=</span> <span class="s2">&quot;custom_stdc&quot;</span>
  295. <span class="n">REGSEG48</span> <span class="o">=</span> <span class="s2">&quot;regseg48&quot;</span>
  296. <span class="n">KD_MODULE</span> <span class="o">=</span> <span class="s2">&quot;kd_module&quot;</span>
  297. <span class="n">VIT_BASE</span> <span class="o">=</span> <span class="s2">&quot;vit_base&quot;</span>
  298. <span class="n">VIT_LARGE</span> <span class="o">=</span> <span class="s2">&quot;vit_large&quot;</span>
  299. <span class="n">VIT_HUGE</span> <span class="o">=</span> <span class="s2">&quot;vit_huge&quot;</span>
  300. <span class="n">BEIT_BASE_PATCH16_224</span> <span class="o">=</span> <span class="s2">&quot;beit_base_patch16_224&quot;</span>
  301. <span class="n">BEIT_LARGE_PATCH16_224</span> <span class="o">=</span> <span class="s2">&quot;beit_large_patch16_224&quot;</span>
  302. <span class="n">PP_LITE_T_SEG</span> <span class="o">=</span> <span class="s2">&quot;pp_lite_t_seg&quot;</span>
  303. <span class="n">PP_LITE_T_SEG50</span> <span class="o">=</span> <span class="s2">&quot;pp_lite_t_seg50&quot;</span>
  304. <span class="n">PP_LITE_T_SEG75</span> <span class="o">=</span> <span class="s2">&quot;pp_lite_t_seg75&quot;</span>
  305. <span class="n">PP_LITE_B_SEG</span> <span class="o">=</span> <span class="s2">&quot;pp_lite_b_seg&quot;</span>
  306. <span class="n">PP_LITE_B_SEG50</span> <span class="o">=</span> <span class="s2">&quot;pp_lite_b_seg50&quot;</span>
  307. <span class="n">PP_LITE_B_SEG75</span> <span class="o">=</span> <span class="s2">&quot;pp_lite_b_seg75&quot;</span>
  308. <span class="n">UNET_CUSTOM</span> <span class="o">=</span> <span class="s2">&quot;unet_custom&quot;</span>
  309. </pre></div>
  310. </div>
  311. </div>
  312. <footer>
  313. <hr/>
  314. <div role="contentinfo">
  315. <p>&#169; Copyright 2021, SuperGradients team.</p>
  316. </div>
  317. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  318. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  319. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  320. </footer>
  321. </div>
  322. </div>
  323. </section>
  324. </div>
  325. <script>
  326. jQuery(function () {
  327. SphinxRtdTheme.Navigation.enable(true);
  328. });
  329. </script>
  330. </body>
  331. </html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.training.dataloaders.dataloaders &mdash; SuperGradients 3.0.3 documentation</title>
  7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
  10. <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
  11. <!--[if lt IE 9]>
  12. <script src="../../../../_static/js/html5shiv.min.js"></script>
  13. <![endif]-->
  14. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
  15. <script src="../../../../_static/jquery.js"></script>
  16. <script src="../../../../_static/underscore.js"></script>
  17. <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
  18. <script src="../../../../_static/doctools.js"></script>
  19. <script src="../../../../_static/sphinx_highlight.js"></script>
  20. <script src="../../../../_static/js/theme.js"></script>
  21. <link rel="index" title="Index" href="../../../../genindex.html" />
  22. <link rel="search" title="Search" href="../../../../search.html" />
  23. </head>
  24. <body class="wy-body-for-nav">
  25. <div class="wy-grid-for-nav">
  26. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  27. <div class="wy-side-scroll">
  28. <div class="wy-side-nav-search" >
  29. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
  30. </a>
  31. <div role="search">
  32. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
  33. <input type="text" name="q" placeholder="Search docs" />
  34. <input type="hidden" name="check_keywords" value="yes" />
  35. <input type="hidden" name="area" value="default" />
  36. </form>
  37. </div>
  38. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  39. <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
  40. <ul>
  41. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
  42. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
  45. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
  46. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
  47. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
  57. </ul>
  58. <p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
  59. <ul>
  60. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
  61. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
  62. </ul>
  63. </div>
  64. </div>
  65. </nav>
  66. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  67. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  68. <a href="../../../../index.html">SuperGradients</a>
  69. </nav>
  70. <div class="wy-nav-content">
  71. <div class="rst-content">
  72. <div role="navigation" aria-label="Page navigation">
  73. <ul class="wy-breadcrumbs">
  74. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  75. <li><a href="../../../index.html">Module code</a> &raquo;</li>
  76. <li>super_gradients.training.dataloaders.dataloaders</li>
  77. <li class="wy-breadcrumbs-aside">
  78. </li>
  79. </ul>
  80. <hr/>
  81. </div>
  82. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  83. <div itemprop="articleBody">
  84. <h1>Source code for super_gradients.training.dataloaders.dataloaders</h1><div class="highlight"><pre>
  85. <span></span><span class="kn">import</span> <span class="nn">os.path</span>
  86. <span class="kn">import</span> <span class="nn">pkg_resources</span>
  87. <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span>
  88. <span class="kn">import</span> <span class="nn">hydra</span>
  89. <span class="kn">from</span> <span class="nn">hydra</span> <span class="kn">import</span> <span class="n">compose</span><span class="p">,</span> <span class="n">initialize_config_dir</span>
  90. <span class="kn">from</span> <span class="nn">hydra.core.global_hydra</span> <span class="kn">import</span> <span class="n">GlobalHydra</span>
  91. <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
  92. <span class="kn">import</span> <span class="nn">torch</span>
  93. <span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">BatchSampler</span><span class="p">,</span> <span class="n">DataLoader</span><span class="p">,</span> <span class="n">TensorDataset</span>
  94. <span class="kn">import</span> <span class="nn">super_gradients</span>
  95. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.detection_datasets.pascal_voc_detection</span> <span class="kn">import</span> <span class="p">(</span>
  96. <span class="n">PascalVOCUnifiedDetectionTrainDataset</span><span class="p">,</span>
  97. <span class="n">PascalVOCDetectionDataset</span><span class="p">,</span>
  98. <span class="p">)</span>
  99. <span class="kn">from</span> <span class="nn">super_gradients.training.utils</span> <span class="kn">import</span> <span class="n">get_param</span>
  100. <span class="kn">from</span> <span class="nn">super_gradients.training.utils.hydra_utils</span> <span class="kn">import</span> <span class="n">normalize_path</span>
  101. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets</span> <span class="kn">import</span> <span class="n">ImageNetDataset</span>
  102. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.detection_datasets</span> <span class="kn">import</span> <span class="n">COCODetectionDataset</span>
  103. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.classification_datasets.cifar</span> <span class="kn">import</span> <span class="p">(</span>
  104. <span class="n">Cifar10</span><span class="p">,</span>
  105. <span class="n">Cifar100</span><span class="p">,</span>
  106. <span class="p">)</span>
  107. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.segmentation_datasets</span> <span class="kn">import</span> <span class="p">(</span>
  108. <span class="n">CityscapesDataset</span><span class="p">,</span>
  109. <span class="n">CoCoSegmentationDataSet</span><span class="p">,</span>
  110. <span class="n">PascalVOC2012SegmentationDataSet</span><span class="p">,</span>
  111. <span class="n">PascalVOCAndAUGUnifiedDataset</span><span class="p">,</span>
  112. <span class="n">SuperviselyPersonsDataset</span><span class="p">,</span>
  113. <span class="p">)</span>
  114. <span class="kn">from</span> <span class="nn">super_gradients.common.factories.samplers_factory</span> <span class="kn">import</span> <span class="n">SamplersFactory</span>
  115. <span class="kn">from</span> <span class="nn">super_gradients.training.utils.distributed_training_utils</span> <span class="kn">import</span> <span class="p">(</span>
  116. <span class="n">wait_for_the_master</span><span class="p">,</span>
  117. <span class="n">get_local_rank</span><span class="p">,</span>
  118. <span class="p">)</span>
  119. <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
  120. <span class="kn">from</span> <span class="nn">super_gradients.training.utils.utils</span> <span class="kn">import</span> <span class="n">override_default_params_without_nones</span>
  121. <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
  122. <div class="viewcode-block" id="get_data_loader"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.get_data_loader">[docs]</a><span class="k">def</span> <span class="nf">get_data_loader</span><span class="p">(</span><span class="n">config_name</span><span class="p">,</span> <span class="n">dataset_cls</span><span class="p">,</span> <span class="n">train</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  123. <span class="sd">&quot;&quot;&quot;</span>
  124. <span class="sd"> Class for creating dataloaders for taking defaults from yaml files in src/super_gradients/recipes.</span>
  125. <span class="sd"> :param config_name: yaml config filename in recipes (for example coco2017_yolox).</span>
  126. <span class="sd"> :param dataset_cls: torch dataset uninitialized class.</span>
  127. <span class="sd"> :param train: controls whether to take</span>
  128. <span class="sd"> cfg.dataset_params.train_dataloader_params or cfg.dataset_params.valid_dataloader_params as defaults for the dataset constructor</span>
  129. <span class="sd"> and</span>
  130. <span class="sd"> cfg.dataset_params.train_dataset_params or cfg.dataset_params.valid_dataset_params as defaults for DataLoader contructor.</span>
  131. <span class="sd"> :param dataset_params: dataset params that override the yaml configured defaults, then passed to the dataset_cls.__init__.</span>
  132. <span class="sd"> :param dataloader_params: DataLoader params that override the yaml configured defaults, then passed to the DataLoader.__init__</span>
  133. <span class="sd"> :return: DataLoader</span>
  134. <span class="sd"> &quot;&quot;&quot;</span>
  135. <span class="k">if</span> <span class="n">dataloader_params</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
  136. <span class="n">dataloader_params</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
  137. <span class="k">if</span> <span class="n">dataset_params</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
  138. <span class="n">dataset_params</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
  139. <span class="n">GlobalHydra</span><span class="o">.</span><span class="n">instance</span><span class="p">()</span><span class="o">.</span><span class="n">clear</span><span class="p">()</span>
  140. <span class="n">sg_recipes_dir</span> <span class="o">=</span> <span class="n">pkg_resources</span><span class="o">.</span><span class="n">resource_filename</span><span class="p">(</span><span class="s2">&quot;super_gradients.recipes&quot;</span><span class="p">,</span> <span class="s2">&quot;&quot;</span><span class="p">)</span>
  141. <span class="n">dataset_config</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="s2">&quot;dataset_params&quot;</span><span class="p">,</span> <span class="n">config_name</span><span class="p">)</span>
  142. <span class="k">with</span> <span class="n">initialize_config_dir</span><span class="p">(</span><span class="n">config_dir</span><span class="o">=</span><span class="n">normalize_path</span><span class="p">(</span><span class="n">sg_recipes_dir</span><span class="p">),</span> <span class="n">version_base</span><span class="o">=</span><span class="s2">&quot;1.2&quot;</span><span class="p">):</span>
  143. <span class="c1"># config is relative to a module</span>
  144. <span class="n">cfg</span> <span class="o">=</span> <span class="n">compose</span><span class="p">(</span><span class="n">config_name</span><span class="o">=</span><span class="n">normalize_path</span><span class="p">(</span><span class="n">dataset_config</span><span class="p">))</span>
  145. <span class="n">dataset_params</span> <span class="o">=</span> <span class="n">_process_dataset_params</span><span class="p">(</span><span class="n">cfg</span><span class="p">,</span> <span class="n">dataset_params</span><span class="p">,</span> <span class="n">train</span><span class="p">)</span>
  146. <span class="n">local_rank</span> <span class="o">=</span> <span class="n">get_local_rank</span><span class="p">()</span>
  147. <span class="k">with</span> <span class="n">wait_for_the_master</span><span class="p">(</span><span class="n">local_rank</span><span class="p">):</span>
  148. <span class="n">dataset</span> <span class="o">=</span> <span class="n">dataset_cls</span><span class="p">(</span><span class="o">**</span><span class="n">dataset_params</span><span class="p">)</span>
  149. <span class="k">if</span> <span class="ow">not</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="s2">&quot;dataset_params&quot;</span><span class="p">):</span>
  150. <span class="n">dataset</span><span class="o">.</span><span class="n">dataset_params</span> <span class="o">=</span> <span class="n">dataset_params</span>
  151. <span class="n">dataloader_params</span> <span class="o">=</span> <span class="n">_process_dataloader_params</span><span class="p">(</span><span class="n">cfg</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">,</span> <span class="n">dataset</span><span class="p">,</span> <span class="n">train</span><span class="p">)</span>
  152. <span class="n">dataloader</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">dataset</span><span class="o">=</span><span class="n">dataset</span><span class="p">,</span> <span class="o">**</span><span class="n">dataloader_params</span><span class="p">)</span>
  153. <span class="n">dataloader</span><span class="o">.</span><span class="n">dataloader_params</span> <span class="o">=</span> <span class="n">dataloader_params</span>
  154. <span class="k">return</span> <span class="n">dataloader</span></div>
  155. <span class="k">def</span> <span class="nf">_process_dataset_params</span><span class="p">(</span><span class="n">cfg</span><span class="p">,</span> <span class="n">dataset_params</span><span class="p">,</span> <span class="n">train</span><span class="p">):</span>
  156. <span class="n">default_dataset_params</span> <span class="o">=</span> <span class="n">cfg</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">train_dataset_params</span> <span class="k">if</span> <span class="n">train</span> <span class="k">else</span> <span class="n">cfg</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">val_dataset_params</span>
  157. <span class="n">default_dataset_params</span> <span class="o">=</span> <span class="n">hydra</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">instantiate</span><span class="p">(</span><span class="n">default_dataset_params</span><span class="p">)</span>
  158. <span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">val</span> <span class="ow">in</span> <span class="n">default_dataset_params</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
  159. <span class="k">if</span> <span class="n">key</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">dataset_params</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span> <span class="ow">or</span> <span class="n">dataset_params</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
  160. <span class="n">dataset_params</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">val</span>
  161. <span class="k">return</span> <span class="n">dataset_params</span>
  162. <span class="k">def</span> <span class="nf">_process_dataloader_params</span><span class="p">(</span><span class="n">cfg</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">,</span> <span class="n">dataset</span><span class="p">,</span> <span class="n">train</span><span class="p">):</span>
  163. <span class="n">default_dataloader_params</span> <span class="o">=</span> <span class="n">cfg</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">train_dataloader_params</span> <span class="k">if</span> <span class="n">train</span> <span class="k">else</span> <span class="n">cfg</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">val_dataloader_params</span>
  164. <span class="n">default_dataloader_params</span> <span class="o">=</span> <span class="n">hydra</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">instantiate</span><span class="p">(</span><span class="n">default_dataloader_params</span><span class="p">)</span>
  165. <span class="n">dataloader_params</span> <span class="o">=</span> <span class="n">_process_sampler_params</span><span class="p">(</span><span class="n">dataloader_params</span><span class="p">,</span> <span class="n">dataset</span><span class="p">,</span> <span class="n">default_dataloader_params</span><span class="p">)</span>
  166. <span class="k">return</span> <span class="n">dataloader_params</span>
  167. <span class="k">def</span> <span class="nf">_process_sampler_params</span><span class="p">(</span><span class="n">dataloader_params</span><span class="p">,</span> <span class="n">dataset</span><span class="p">,</span> <span class="n">default_dataloader_params</span><span class="p">):</span>
  168. <span class="n">is_dist</span> <span class="o">=</span> <span class="n">super_gradients</span><span class="o">.</span><span class="n">is_distributed</span><span class="p">()</span>
  169. <span class="k">if</span> <span class="n">get_param</span><span class="p">(</span><span class="n">dataloader_params</span><span class="p">,</span> <span class="s2">&quot;sampler&quot;</span><span class="p">)</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  170. <span class="n">dataloader_params</span> <span class="o">=</span> <span class="n">_instantiate_sampler</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">)</span>
  171. <span class="k">elif</span> <span class="n">get_param</span><span class="p">(</span><span class="n">default_dataloader_params</span><span class="p">,</span> <span class="s2">&quot;sampler&quot;</span><span class="p">)</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  172. <span class="n">default_dataloader_params</span> <span class="o">=</span> <span class="n">_instantiate_sampler</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">default_dataloader_params</span><span class="p">)</span>
  173. <span class="k">elif</span> <span class="n">is_dist</span><span class="p">:</span>
  174. <span class="n">default_dataloader_params</span><span class="p">[</span><span class="s2">&quot;sampler&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;DistributedSampler&quot;</span><span class="p">:</span> <span class="p">{}}</span>
  175. <span class="n">default_dataloader_params</span> <span class="o">=</span> <span class="n">_instantiate_sampler</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">default_dataloader_params</span><span class="p">)</span>
  176. <span class="n">dataloader_params</span> <span class="o">=</span> <span class="n">override_default_params_without_nones</span><span class="p">(</span><span class="n">dataloader_params</span><span class="p">,</span> <span class="n">default_dataloader_params</span><span class="p">)</span>
  177. <span class="k">if</span> <span class="n">get_param</span><span class="p">(</span><span class="n">dataloader_params</span><span class="p">,</span> <span class="s2">&quot;batch_sampler&quot;</span><span class="p">):</span>
  178. <span class="n">sampler</span> <span class="o">=</span> <span class="n">dataloader_params</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;sampler&quot;</span><span class="p">)</span>
  179. <span class="n">batch_size</span> <span class="o">=</span> <span class="n">dataloader_params</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;batch_size&quot;</span><span class="p">)</span>
  180. <span class="k">if</span> <span class="s2">&quot;drop_last&quot;</span> <span class="ow">in</span> <span class="n">dataloader_params</span><span class="p">:</span>
  181. <span class="n">drop_last</span> <span class="o">=</span> <span class="n">dataloader_params</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;drop_last&quot;</span><span class="p">)</span>
  182. <span class="k">else</span><span class="p">:</span>
  183. <span class="n">drop_last</span> <span class="o">=</span> <span class="n">default_dataloader_params</span><span class="p">[</span><span class="s2">&quot;drop_last&quot;</span><span class="p">]</span>
  184. <span class="n">dataloader_params</span><span class="p">[</span><span class="s2">&quot;batch_sampler&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">BatchSampler</span><span class="p">(</span><span class="n">sampler</span><span class="o">=</span><span class="n">sampler</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">drop_last</span><span class="o">=</span><span class="n">drop_last</span><span class="p">)</span>
  185. <span class="k">return</span> <span class="n">dataloader_params</span>
  186. <span class="k">def</span> <span class="nf">_instantiate_sampler</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">):</span>
  187. <span class="n">sampler_name</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">dataloader_params</span><span class="p">[</span><span class="s2">&quot;sampler&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">keys</span><span class="p">())[</span><span class="mi">0</span><span class="p">]</span>
  188. <span class="n">dataloader_params</span><span class="p">[</span><span class="s2">&quot;sampler&quot;</span><span class="p">][</span><span class="n">sampler_name</span><span class="p">][</span><span class="s2">&quot;dataset&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">dataset</span>
  189. <span class="n">dataloader_params</span><span class="p">[</span><span class="s2">&quot;sampler&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">SamplersFactory</span><span class="p">()</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">dataloader_params</span><span class="p">[</span><span class="s2">&quot;sampler&quot;</span><span class="p">])</span>
  190. <span class="k">return</span> <span class="n">dataloader_params</span>
  191. <div class="viewcode-block" id="coco2017_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.coco2017_train">[docs]</a><span class="k">def</span> <span class="nf">coco2017_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  192. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
  193. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;coco_detection_dataset_params&quot;</span><span class="p">,</span>
  194. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">COCODetectionDataset</span><span class="p">,</span>
  195. <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
  196. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
  197. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
  198. <span class="p">)</span></div>
  199. <div class="viewcode-block" id="coco2017_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.coco2017_val">[docs]</a><span class="k">def</span> <span class="nf">coco2017_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  200. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
  201. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;coco_detection_dataset_params&quot;</span><span class="p">,</span>
  202. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">COCODetectionDataset</span><span class="p">,</span>
  203. <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
  204. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
  205. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
  206. <span class="p">)</span></div>
  207. <div class="viewcode-block" id="coco2017_train_yolox"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.coco2017_train_yolox">[docs]</a><span class="k">def</span> <span class="nf">coco2017_train_yolox</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  208. <span class="k">return</span> <span class="n">coco2017_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">)</span></div>
  209. <div class="viewcode-block" id="coco2017_val_yolox"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.coco2017_val_yolox">[docs]</a><span class="k">def</span> <span class="nf">coco2017_val_yolox</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  210. <span class="k">return</span> <span class="n">coco2017_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">)</span></div>
  211. <div class="viewcode-block" id="coco2017_train_ssd_lite_mobilenet_v2"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.coco2017_train_ssd_lite_mobilenet_v2">[docs]</a><span class="k">def</span> <span class="nf">coco2017_train_ssd_lite_mobilenet_v2</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  212. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
  213. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;coco_detection_ssd_lite_mobilenet_v2_dataset_params&quot;</span><span class="p">,</span>
  214. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">COCODetectionDataset</span><span class="p">,</span>
  215. <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
  216. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
  217. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
  218. <span class="p">)</span></div>
  219. <div class="viewcode-block" id="coco2017_val_ssd_lite_mobilenet_v2"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.coco2017_val_ssd_lite_mobilenet_v2">[docs]</a><span class="k">def</span> <span class="nf">coco2017_val_ssd_lite_mobilenet_v2</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  220. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
  221. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;coco_detection_ssd_lite_mobilenet_v2_dataset_params&quot;</span><span class="p">,</span>
  222. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">COCODetectionDataset</span><span class="p">,</span>
  223. <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
  224. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
  225. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
  226. <span class="p">)</span></div>
  227. <div class="viewcode-block" id="imagenet_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.imagenet_train">[docs]</a><span class="k">def</span> <span class="nf">imagenet_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;imagenet_dataset_params&quot;</span><span class="p">):</span>
  228. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
  229. <span class="n">config_name</span><span class="o">=</span><span class="n">config_name</span><span class="p">,</span>
  230. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">ImageNetDataset</span><span class="p">,</span>
  231. <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
  232. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
  233. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
  234. <span class="p">)</span></div>
  235. <div class="viewcode-block" id="imagenet_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.imagenet_val">[docs]</a><span class="k">def</span> <span class="nf">imagenet_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;imagenet_dataset_params&quot;</span><span class="p">):</span>
  236. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
  237. <span class="n">config_name</span><span class="o">=</span><span class="n">config_name</span><span class="p">,</span>
  238. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">ImageNetDataset</span><span class="p">,</span>
  239. <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
  240. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
  241. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
  242. <span class="p">)</span></div>
  243. <div class="viewcode-block" id="imagenet_efficientnet_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.imagenet_efficientnet_train">[docs]</a><span class="k">def</span> <span class="nf">imagenet_efficientnet_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  244. <span class="k">return</span> <span class="n">imagenet_train</span><span class="p">(</span>
  245. <span class="n">dataset_params</span><span class="p">,</span>
  246. <span class="n">dataloader_params</span><span class="p">,</span>
  247. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;imagenet_efficientnet_dataset_params&quot;</span><span class="p">,</span>
  248. <span class="p">)</span></div>
  249. <div class="viewcode-block" id="imagenet_efficientnet_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.imagenet_efficientnet_val">[docs]</a><span class="k">def</span> <span class="nf">imagenet_efficientnet_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  250. <span class="k">return</span> <span class="n">imagenet_val</span><span class="p">(</span>
  251. <span class="n">dataset_params</span><span class="p">,</span>
  252. <span class="n">dataloader_params</span><span class="p">,</span>
  253. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;imagenet_efficientnet_dataset_params&quot;</span><span class="p">,</span>
  254. <span class="p">)</span></div>
  255. <div class="viewcode-block" id="imagenet_mobilenetv2_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.imagenet_mobilenetv2_train">[docs]</a><span class="k">def</span> <span class="nf">imagenet_mobilenetv2_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  256. <span class="k">return</span> <span class="n">imagenet_train</span><span class="p">(</span>
  257. <span class="n">dataset_params</span><span class="p">,</span>
  258. <span class="n">dataloader_params</span><span class="p">,</span>
  259. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;imagenet_mobilenetv2_dataset_params&quot;</span><span class="p">,</span>
  260. <span class="p">)</span></div>
  261. <div class="viewcode-block" id="imagenet_mobilenetv2_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.imagenet_mobilenetv2_val">[docs]</a><span class="k">def</span> <span class="nf">imagenet_mobilenetv2_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  262. <span class="k">return</span> <span class="n">imagenet_val</span><span class="p">(</span>
  263. <span class="n">dataset_params</span><span class="p">,</span>
  264. <span class="n">dataloader_params</span><span class="p">,</span>
  265. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;imagenet_mobilenetv2_dataset_params&quot;</span><span class="p">,</span>
  266. <span class="p">)</span></div>
  267. <div class="viewcode-block" id="imagenet_mobilenetv3_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.imagenet_mobilenetv3_train">[docs]</a><span class="k">def</span> <span class="nf">imagenet_mobilenetv3_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  268. <span class="k">return</span> <span class="n">imagenet_train</span><span class="p">(</span>
  269. <span class="n">dataset_params</span><span class="p">,</span>
  270. <span class="n">dataloader_params</span><span class="p">,</span>
  271. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;imagenet_mobilenetv3_dataset_params&quot;</span><span class="p">,</span>
  272. <span class="p">)</span></div>
  273. <div class="viewcode-block" id="imagenet_mobilenetv3_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.imagenet_mobilenetv3_val">[docs]</a><span class="k">def</span> <span class="nf">imagenet_mobilenetv3_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  274. <span class="k">return</span> <span class="n">imagenet_val</span><span class="p">(</span>
  275. <span class="n">dataset_params</span><span class="p">,</span>
  276. <span class="n">dataloader_params</span><span class="p">,</span>
  277. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;imagenet_mobilenetv3_dataset_params&quot;</span><span class="p">,</span>
  278. <span class="p">)</span></div>
  279. <div class="viewcode-block" id="imagenet_regnetY_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.imagenet_regnetY_train">[docs]</a><span class="k">def</span> <span class="nf">imagenet_regnetY_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  280. <span class="k">return</span> <span class="n">imagenet_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">,</span> <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;imagenet_regnetY_dataset_params&quot;</span><span class="p">)</span></div>
  281. <div class="viewcode-block" id="imagenet_regnetY_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.imagenet_regnetY_val">[docs]</a><span class="k">def</span> <span class="nf">imagenet_regnetY_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  282. <span class="k">return</span> <span class="n">imagenet_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">,</span> <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;imagenet_regnetY_dataset_params&quot;</span><span class="p">)</span></div>
  283. <div class="viewcode-block" id="imagenet_resnet50_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.imagenet_resnet50_train">[docs]</a><span class="k">def</span> <span class="nf">imagenet_resnet50_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  284. <span class="k">return</span> <span class="n">imagenet_train</span><span class="p">(</span>
  285. <span class="n">dataset_params</span><span class="p">,</span>
  286. <span class="n">dataloader_params</span><span class="p">,</span>
  287. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;imagenet_resnet50_dataset_params&quot;</span><span class="p">,</span>
  288. <span class="p">)</span></div>
  289. <div class="viewcode-block" id="imagenet_resnet50_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.imagenet_resnet50_val">[docs]</a><span class="k">def</span> <span class="nf">imagenet_resnet50_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  290. <span class="k">return</span> <span class="n">imagenet_val</span><span class="p">(</span>
  291. <span class="n">dataset_params</span><span class="p">,</span>
  292. <span class="n">dataloader_params</span><span class="p">,</span>
  293. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;imagenet_resnet50_dataset_params&quot;</span><span class="p">,</span>
  294. <span class="p">)</span></div>
  295. <div class="viewcode-block" id="imagenet_resnet50_kd_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.imagenet_resnet50_kd_train">[docs]</a><span class="k">def</span> <span class="nf">imagenet_resnet50_kd_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  296. <span class="k">return</span> <span class="n">imagenet_train</span><span class="p">(</span>
  297. <span class="n">dataset_params</span><span class="p">,</span>
  298. <span class="n">dataloader_params</span><span class="p">,</span>
  299. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;imagenet_resnet50_kd_dataset_params&quot;</span><span class="p">,</span>
  300. <span class="p">)</span></div>
  301. <div class="viewcode-block" id="imagenet_resnet50_kd_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.imagenet_resnet50_kd_val">[docs]</a><span class="k">def</span> <span class="nf">imagenet_resnet50_kd_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  302. <span class="k">return</span> <span class="n">imagenet_val</span><span class="p">(</span>
  303. <span class="n">dataset_params</span><span class="p">,</span>
  304. <span class="n">dataloader_params</span><span class="p">,</span>
  305. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;imagenet_resnet50_kd_dataset_params&quot;</span><span class="p">,</span>
  306. <span class="p">)</span></div>
  307. <div class="viewcode-block" id="imagenet_vit_base_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.imagenet_vit_base_train">[docs]</a><span class="k">def</span> <span class="nf">imagenet_vit_base_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  308. <span class="k">return</span> <span class="n">imagenet_train</span><span class="p">(</span>
  309. <span class="n">dataset_params</span><span class="p">,</span>
  310. <span class="n">dataloader_params</span><span class="p">,</span>
  311. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;imagenet_vit_base_dataset_params&quot;</span><span class="p">,</span>
  312. <span class="p">)</span></div>
  313. <div class="viewcode-block" id="imagenet_vit_base_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.imagenet_vit_base_val">[docs]</a><span class="k">def</span> <span class="nf">imagenet_vit_base_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  314. <span class="k">return</span> <span class="n">imagenet_val</span><span class="p">(</span>
  315. <span class="n">dataset_params</span><span class="p">,</span>
  316. <span class="n">dataloader_params</span><span class="p">,</span>
  317. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;imagenet_vit_base_dataset_params&quot;</span><span class="p">,</span>
  318. <span class="p">)</span></div>
  319. <div class="viewcode-block" id="tiny_imagenet_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.tiny_imagenet_train">[docs]</a><span class="k">def</span> <span class="nf">tiny_imagenet_train</span><span class="p">(</span>
  320. <span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
  321. <span class="n">dataloader_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
  322. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;tiny_imagenet_dataset_params&quot;</span><span class="p">,</span>
  323. <span class="p">):</span>
  324. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
  325. <span class="n">config_name</span><span class="o">=</span><span class="n">config_name</span><span class="p">,</span>
  326. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">ImageNetDataset</span><span class="p">,</span>
  327. <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
  328. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
  329. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
  330. <span class="p">)</span></div>
  331. <div class="viewcode-block" id="tiny_imagenet_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.tiny_imagenet_val">[docs]</a><span class="k">def</span> <span class="nf">tiny_imagenet_val</span><span class="p">(</span>
  332. <span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
  333. <span class="n">dataloader_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
  334. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;tiny_imagenet_dataset_params&quot;</span><span class="p">,</span>
  335. <span class="p">):</span>
  336. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
  337. <span class="n">config_name</span><span class="o">=</span><span class="n">config_name</span><span class="p">,</span>
  338. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">ImageNetDataset</span><span class="p">,</span>
  339. <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
  340. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
  341. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
  342. <span class="p">)</span></div>
  343. <div class="viewcode-block" id="cifar10_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.cifar10_train">[docs]</a><span class="k">def</span> <span class="nf">cifar10_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  344. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
  345. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;cifar10_dataset_params&quot;</span><span class="p">,</span>
  346. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">Cifar10</span><span class="p">,</span>
  347. <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
  348. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
  349. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
  350. <span class="p">)</span></div>
  351. <div class="viewcode-block" id="cifar10_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.cifar10_val">[docs]</a><span class="k">def</span> <span class="nf">cifar10_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  352. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
  353. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;cifar10_dataset_params&quot;</span><span class="p">,</span>
  354. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">Cifar10</span><span class="p">,</span>
  355. <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
  356. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
  357. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
  358. <span class="p">)</span></div>
  359. <div class="viewcode-block" id="cifar100_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.cifar100_train">[docs]</a><span class="k">def</span> <span class="nf">cifar100_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  360. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
  361. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;cifar100_dataset_params&quot;</span><span class="p">,</span>
  362. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">Cifar100</span><span class="p">,</span>
  363. <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
  364. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
  365. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
  366. <span class="p">)</span></div>
  367. <div class="viewcode-block" id="cifar100_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.cifar100_val">[docs]</a><span class="k">def</span> <span class="nf">cifar100_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  368. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
  369. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;cifar100_dataset_params&quot;</span><span class="p">,</span>
  370. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">Cifar100</span><span class="p">,</span>
  371. <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
  372. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
  373. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
  374. <span class="p">)</span></div>
  375. <span class="k">def</span> <span class="nf">classification_test_dataloader</span><span class="p">(</span><span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">5</span><span class="p">,</span> <span class="n">image_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span><span class="p">,</span> <span class="n">dataset_size</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">DataLoader</span><span class="p">:</span>
  376. <span class="n">dataset_size</span> <span class="o">=</span> <span class="n">dataset_size</span> <span class="ow">or</span> <span class="n">batch_size</span>
  377. <span class="n">images</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">dataset_size</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="n">image_size</span><span class="p">)))</span>
  378. <span class="n">ground_truth</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">LongTensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">dataset_size</span><span class="p">)))</span>
  379. <span class="n">dataset</span> <span class="o">=</span> <span class="n">TensorDataset</span><span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">ground_truth</span><span class="p">)</span>
  380. <span class="k">return</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">dataset</span><span class="o">=</span><span class="n">dataset</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">)</span>
  381. <span class="k">def</span> <span class="nf">detection_test_dataloader</span><span class="p">(</span><span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">5</span><span class="p">,</span> <span class="n">image_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">320</span><span class="p">,</span> <span class="n">dataset_size</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">DataLoader</span><span class="p">:</span>
  382. <span class="n">dataset_size</span> <span class="o">=</span> <span class="n">dataset_size</span> <span class="ow">or</span> <span class="n">batch_size</span>
  383. <span class="n">images</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">dataset_size</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="n">image_size</span><span class="p">)))</span>
  384. <span class="n">ground_truth</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">dataset_size</span><span class="p">,</span> <span class="mi">6</span><span class="p">)))</span>
  385. <span class="n">dataset</span> <span class="o">=</span> <span class="n">TensorDataset</span><span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">ground_truth</span><span class="p">)</span>
  386. <span class="k">return</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">dataset</span><span class="o">=</span><span class="n">dataset</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">)</span>
  387. <span class="k">def</span> <span class="nf">segmentation_test_dataloader</span><span class="p">(</span><span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">5</span><span class="p">,</span> <span class="n">image_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">512</span><span class="p">,</span> <span class="n">dataset_size</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">DataLoader</span><span class="p">:</span>
  388. <span class="n">dataset_size</span> <span class="o">=</span> <span class="n">dataset_size</span> <span class="ow">or</span> <span class="n">batch_size</span>
  389. <span class="n">images</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">dataset_size</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="n">image_size</span><span class="p">)))</span>
  390. <span class="n">ground_truth</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">LongTensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">dataset_size</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="n">image_size</span><span class="p">)))</span>
  391. <span class="n">dataset</span> <span class="o">=</span> <span class="n">TensorDataset</span><span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">ground_truth</span><span class="p">)</span>
  392. <span class="k">return</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">dataset</span><span class="o">=</span><span class="n">dataset</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">)</span>
  393. <div class="viewcode-block" id="cityscapes_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.cityscapes_train">[docs]</a><span class="k">def</span> <span class="nf">cityscapes_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  394. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
  395. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;cityscapes_dataset_params&quot;</span><span class="p">,</span>
  396. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">CityscapesDataset</span><span class="p">,</span>
  397. <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
  398. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
  399. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
  400. <span class="p">)</span></div>
  401. <div class="viewcode-block" id="cityscapes_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.cityscapes_val">[docs]</a><span class="k">def</span> <span class="nf">cityscapes_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  402. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
  403. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;cityscapes_dataset_params&quot;</span><span class="p">,</span>
  404. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">CityscapesDataset</span><span class="p">,</span>
  405. <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
  406. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
  407. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
  408. <span class="p">)</span></div>
  409. <div class="viewcode-block" id="cityscapes_stdc_seg50_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.cityscapes_stdc_seg50_train">[docs]</a><span class="k">def</span> <span class="nf">cityscapes_stdc_seg50_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  410. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
  411. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;cityscapes_stdc_seg50_dataset_params&quot;</span><span class="p">,</span>
  412. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">CityscapesDataset</span><span class="p">,</span>
  413. <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
  414. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
  415. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
  416. <span class="p">)</span></div>
  417. <div class="viewcode-block" id="cityscapes_stdc_seg50_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.cityscapes_stdc_seg50_val">[docs]</a><span class="k">def</span> <span class="nf">cityscapes_stdc_seg50_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  418. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
  419. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;cityscapes_stdc_seg50_dataset_params&quot;</span><span class="p">,</span>
  420. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">CityscapesDataset</span><span class="p">,</span>
  421. <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
  422. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
  423. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
  424. <span class="p">)</span></div>
  425. <div class="viewcode-block" id="cityscapes_stdc_seg75_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.cityscapes_stdc_seg75_train">[docs]</a><span class="k">def</span> <span class="nf">cityscapes_stdc_seg75_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  426. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
  427. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;cityscapes_stdc_seg75_dataset_params&quot;</span><span class="p">,</span>
  428. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">CityscapesDataset</span><span class="p">,</span>
  429. <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
  430. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
  431. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
  432. <span class="p">)</span></div>
  433. <div class="viewcode-block" id="cityscapes_stdc_seg75_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.cityscapes_stdc_seg75_val">[docs]</a><span class="k">def</span> <span class="nf">cityscapes_stdc_seg75_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  434. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
  435. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;cityscapes_stdc_seg75_dataset_params&quot;</span><span class="p">,</span>
  436. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">CityscapesDataset</span><span class="p">,</span>
  437. <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
  438. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
  439. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
  440. <span class="p">)</span></div>
  441. <div class="viewcode-block" id="cityscapes_regseg48_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.cityscapes_regseg48_train">[docs]</a><span class="k">def</span> <span class="nf">cityscapes_regseg48_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  442. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
  443. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;cityscapes_regseg48_dataset_params&quot;</span><span class="p">,</span>
  444. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">CityscapesDataset</span><span class="p">,</span>
  445. <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
  446. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
  447. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
  448. <span class="p">)</span></div>
  449. <div class="viewcode-block" id="cityscapes_regseg48_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.cityscapes_regseg48_val">[docs]</a><span class="k">def</span> <span class="nf">cityscapes_regseg48_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  450. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
  451. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;cityscapes_regseg48_dataset_params&quot;</span><span class="p">,</span>
  452. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">CityscapesDataset</span><span class="p">,</span>
  453. <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
  454. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
  455. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
  456. <span class="p">)</span></div>
  457. <div class="viewcode-block" id="cityscapes_ddrnet_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.cityscapes_ddrnet_train">[docs]</a><span class="k">def</span> <span class="nf">cityscapes_ddrnet_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  458. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
  459. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;cityscapes_ddrnet_dataset_params&quot;</span><span class="p">,</span>
  460. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">CityscapesDataset</span><span class="p">,</span>
  461. <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
  462. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
  463. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
  464. <span class="p">)</span></div>
  465. <div class="viewcode-block" id="cityscapes_ddrnet_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.cityscapes_ddrnet_val">[docs]</a><span class="k">def</span> <span class="nf">cityscapes_ddrnet_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  466. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
  467. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;cityscapes_ddrnet_dataset_params&quot;</span><span class="p">,</span>
  468. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">CityscapesDataset</span><span class="p">,</span>
  469. <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
  470. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
  471. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
  472. <span class="p">)</span></div>
  473. <div class="viewcode-block" id="coco_segmentation_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.coco_segmentation_train">[docs]</a><span class="k">def</span> <span class="nf">coco_segmentation_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  474. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
  475. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;coco_segmentation_dataset_params&quot;</span><span class="p">,</span>
  476. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">CoCoSegmentationDataSet</span><span class="p">,</span>
  477. <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
  478. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
  479. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
  480. <span class="p">)</span></div>
  481. <div class="viewcode-block" id="coco_segmentation_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.coco_segmentation_val">[docs]</a><span class="k">def</span> <span class="nf">coco_segmentation_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  482. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
  483. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;coco_segmentation_dataset_params&quot;</span><span class="p">,</span>
  484. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">CoCoSegmentationDataSet</span><span class="p">,</span>
  485. <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
  486. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
  487. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
  488. <span class="p">)</span></div>
  489. <div class="viewcode-block" id="pascal_aug_segmentation_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.pascal_aug_segmentation_train">[docs]</a><span class="k">def</span> <span class="nf">pascal_aug_segmentation_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  490. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
  491. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;pascal_aug_segmentation_dataset_params&quot;</span><span class="p">,</span>
  492. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">PascalVOCAndAUGUnifiedDataset</span><span class="p">,</span>
  493. <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
  494. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
  495. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
  496. <span class="p">)</span></div>
  497. <div class="viewcode-block" id="pascal_aug_segmentation_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.pascal_aug_segmentation_val">[docs]</a><span class="k">def</span> <span class="nf">pascal_aug_segmentation_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  498. <span class="k">return</span> <span class="n">pascal_voc_segmentation_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">)</span></div>
  499. <div class="viewcode-block" id="pascal_voc_segmentation_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.pascal_voc_segmentation_train">[docs]</a><span class="k">def</span> <span class="nf">pascal_voc_segmentation_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  500. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
  501. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;pascal_voc_segmentation_dataset_params&quot;</span><span class="p">,</span>
  502. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">PascalVOC2012SegmentationDataSet</span><span class="p">,</span>
  503. <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
  504. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
  505. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
  506. <span class="p">)</span></div>
  507. <div class="viewcode-block" id="pascal_voc_segmentation_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.pascal_voc_segmentation_val">[docs]</a><span class="k">def</span> <span class="nf">pascal_voc_segmentation_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  508. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
  509. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;pascal_voc_segmentation_dataset_params&quot;</span><span class="p">,</span>
  510. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">PascalVOC2012SegmentationDataSet</span><span class="p">,</span>
  511. <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
  512. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
  513. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
  514. <span class="p">)</span></div>
  515. <div class="viewcode-block" id="supervisely_persons_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.supervisely_persons_train">[docs]</a><span class="k">def</span> <span class="nf">supervisely_persons_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  516. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
  517. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;supervisely_persons_dataset_params&quot;</span><span class="p">,</span>
  518. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">SuperviselyPersonsDataset</span><span class="p">,</span>
  519. <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
  520. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
  521. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
  522. <span class="p">)</span></div>
  523. <div class="viewcode-block" id="supervisely_persons_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.supervisely_persons_val">[docs]</a><span class="k">def</span> <span class="nf">supervisely_persons_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  524. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
  525. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;supervisely_persons_dataset_params&quot;</span><span class="p">,</span>
  526. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">SuperviselyPersonsDataset</span><span class="p">,</span>
  527. <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
  528. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
  529. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
  530. <span class="p">)</span></div>
  531. <div class="viewcode-block" id="pascal_voc_detection_train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.pascal_voc_detection_train">[docs]</a><span class="k">def</span> <span class="nf">pascal_voc_detection_train</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  532. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
  533. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;pascal_voc_detection_dataset_params&quot;</span><span class="p">,</span>
  534. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">PascalVOCUnifiedDetectionTrainDataset</span><span class="p">,</span>
  535. <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
  536. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
  537. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
  538. <span class="p">)</span></div>
  539. <div class="viewcode-block" id="pascal_voc_detection_val"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.pascal_voc_detection_val">[docs]</a><span class="k">def</span> <span class="nf">pascal_voc_detection_val</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  540. <span class="k">return</span> <span class="n">get_data_loader</span><span class="p">(</span>
  541. <span class="n">config_name</span><span class="o">=</span><span class="s2">&quot;pascal_voc_detection_dataset_params&quot;</span><span class="p">,</span>
  542. <span class="n">dataset_cls</span><span class="o">=</span><span class="n">PascalVOCDetectionDataset</span><span class="p">,</span>
  543. <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
  544. <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
  545. <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">,</span>
  546. <span class="p">)</span></div>
  547. <span class="n">ALL_DATALOADERS</span> <span class="o">=</span> <span class="p">{</span>
  548. <span class="s2">&quot;coco2017_train&quot;</span><span class="p">:</span> <span class="n">coco2017_train</span><span class="p">,</span>
  549. <span class="s2">&quot;coco2017_val&quot;</span><span class="p">:</span> <span class="n">coco2017_val</span><span class="p">,</span>
  550. <span class="s2">&quot;coco2017_train_yolox&quot;</span><span class="p">:</span> <span class="n">coco2017_train_yolox</span><span class="p">,</span>
  551. <span class="s2">&quot;coco2017_val_yolox&quot;</span><span class="p">:</span> <span class="n">coco2017_val_yolox</span><span class="p">,</span>
  552. <span class="s2">&quot;coco2017_train_ssd_lite_mobilenet_v2&quot;</span><span class="p">:</span> <span class="n">coco2017_train_ssd_lite_mobilenet_v2</span><span class="p">,</span>
  553. <span class="s2">&quot;coco2017_val_ssd_lite_mobilenet_v2&quot;</span><span class="p">:</span> <span class="n">coco2017_val_ssd_lite_mobilenet_v2</span><span class="p">,</span>
  554. <span class="s2">&quot;imagenet_train&quot;</span><span class="p">:</span> <span class="n">imagenet_train</span><span class="p">,</span>
  555. <span class="s2">&quot;imagenet_val&quot;</span><span class="p">:</span> <span class="n">imagenet_val</span><span class="p">,</span>
  556. <span class="s2">&quot;imagenet_efficientnet_train&quot;</span><span class="p">:</span> <span class="n">imagenet_efficientnet_train</span><span class="p">,</span>
  557. <span class="s2">&quot;imagenet_efficientnet_val&quot;</span><span class="p">:</span> <span class="n">imagenet_efficientnet_val</span><span class="p">,</span>
  558. <span class="s2">&quot;imagenet_mobilenetv2_train&quot;</span><span class="p">:</span> <span class="n">imagenet_mobilenetv2_train</span><span class="p">,</span>
  559. <span class="s2">&quot;imagenet_mobilenetv2_val&quot;</span><span class="p">:</span> <span class="n">imagenet_mobilenetv2_val</span><span class="p">,</span>
  560. <span class="s2">&quot;imagenet_mobilenetv3_train&quot;</span><span class="p">:</span> <span class="n">imagenet_mobilenetv3_train</span><span class="p">,</span>
  561. <span class="s2">&quot;imagenet_mobilenetv3_val&quot;</span><span class="p">:</span> <span class="n">imagenet_mobilenetv3_val</span><span class="p">,</span>
  562. <span class="s2">&quot;imagenet_regnetY_train&quot;</span><span class="p">:</span> <span class="n">imagenet_regnetY_train</span><span class="p">,</span>
  563. <span class="s2">&quot;imagenet_regnetY_val&quot;</span><span class="p">:</span> <span class="n">imagenet_regnetY_val</span><span class="p">,</span>
  564. <span class="s2">&quot;imagenet_resnet50_train&quot;</span><span class="p">:</span> <span class="n">imagenet_resnet50_train</span><span class="p">,</span>
  565. <span class="s2">&quot;imagenet_resnet50_val&quot;</span><span class="p">:</span> <span class="n">imagenet_resnet50_val</span><span class="p">,</span>
  566. <span class="s2">&quot;imagenet_resnet50_kd_train&quot;</span><span class="p">:</span> <span class="n">imagenet_resnet50_kd_train</span><span class="p">,</span>
  567. <span class="s2">&quot;imagenet_resnet50_kd_val&quot;</span><span class="p">:</span> <span class="n">imagenet_resnet50_kd_val</span><span class="p">,</span>
  568. <span class="s2">&quot;imagenet_vit_base_train&quot;</span><span class="p">:</span> <span class="n">imagenet_vit_base_train</span><span class="p">,</span>
  569. <span class="s2">&quot;imagenet_vit_base_val&quot;</span><span class="p">:</span> <span class="n">imagenet_vit_base_val</span><span class="p">,</span>
  570. <span class="s2">&quot;tiny_imagenet_train&quot;</span><span class="p">:</span> <span class="n">tiny_imagenet_train</span><span class="p">,</span>
  571. <span class="s2">&quot;tiny_imagenet_val&quot;</span><span class="p">:</span> <span class="n">tiny_imagenet_val</span><span class="p">,</span>
  572. <span class="s2">&quot;cifar10_train&quot;</span><span class="p">:</span> <span class="n">cifar10_train</span><span class="p">,</span>
  573. <span class="s2">&quot;cifar10_val&quot;</span><span class="p">:</span> <span class="n">cifar10_val</span><span class="p">,</span>
  574. <span class="s2">&quot;cifar100_train&quot;</span><span class="p">:</span> <span class="n">cifar100_train</span><span class="p">,</span>
  575. <span class="s2">&quot;cifar100_val&quot;</span><span class="p">:</span> <span class="n">cifar100_val</span><span class="p">,</span>
  576. <span class="s2">&quot;cityscapes_train&quot;</span><span class="p">:</span> <span class="n">cityscapes_train</span><span class="p">,</span>
  577. <span class="s2">&quot;cityscapes_val&quot;</span><span class="p">:</span> <span class="n">cityscapes_val</span><span class="p">,</span>
  578. <span class="s2">&quot;cityscapes_stdc_seg50_train&quot;</span><span class="p">:</span> <span class="n">cityscapes_stdc_seg50_train</span><span class="p">,</span>
  579. <span class="s2">&quot;cityscapes_stdc_seg50_val&quot;</span><span class="p">:</span> <span class="n">cityscapes_stdc_seg50_val</span><span class="p">,</span>
  580. <span class="s2">&quot;cityscapes_stdc_seg75_train&quot;</span><span class="p">:</span> <span class="n">cityscapes_stdc_seg75_train</span><span class="p">,</span>
  581. <span class="s2">&quot;cityscapes_stdc_seg75_val&quot;</span><span class="p">:</span> <span class="n">cityscapes_stdc_seg75_val</span><span class="p">,</span>
  582. <span class="s2">&quot;cityscapes_regseg48_train&quot;</span><span class="p">:</span> <span class="n">cityscapes_regseg48_train</span><span class="p">,</span>
  583. <span class="s2">&quot;cityscapes_regseg48_val&quot;</span><span class="p">:</span> <span class="n">cityscapes_regseg48_val</span><span class="p">,</span>
  584. <span class="s2">&quot;cityscapes_ddrnet_train&quot;</span><span class="p">:</span> <span class="n">cityscapes_ddrnet_train</span><span class="p">,</span>
  585. <span class="s2">&quot;cityscapes_ddrnet_val&quot;</span><span class="p">:</span> <span class="n">cityscapes_ddrnet_val</span><span class="p">,</span>
  586. <span class="s2">&quot;coco_segmentation_train&quot;</span><span class="p">:</span> <span class="n">coco_segmentation_train</span><span class="p">,</span>
  587. <span class="s2">&quot;coco_segmentation_val&quot;</span><span class="p">:</span> <span class="n">coco_segmentation_val</span><span class="p">,</span>
  588. <span class="s2">&quot;pascal_aug_segmentation_train&quot;</span><span class="p">:</span> <span class="n">pascal_aug_segmentation_train</span><span class="p">,</span>
  589. <span class="s2">&quot;pascal_aug_segmentation_val&quot;</span><span class="p">:</span> <span class="n">pascal_aug_segmentation_val</span><span class="p">,</span>
  590. <span class="s2">&quot;pascal_voc_segmentation_train&quot;</span><span class="p">:</span> <span class="n">pascal_voc_segmentation_train</span><span class="p">,</span>
  591. <span class="s2">&quot;pascal_voc_segmentation_val&quot;</span><span class="p">:</span> <span class="n">pascal_voc_segmentation_val</span><span class="p">,</span>
  592. <span class="s2">&quot;supervisely_persons_train&quot;</span><span class="p">:</span> <span class="n">supervisely_persons_train</span><span class="p">,</span>
  593. <span class="s2">&quot;supervisely_persons_val&quot;</span><span class="p">:</span> <span class="n">supervisely_persons_val</span><span class="p">,</span>
  594. <span class="s2">&quot;pascal_voc_detection_train&quot;</span><span class="p">:</span> <span class="n">pascal_voc_detection_train</span><span class="p">,</span>
  595. <span class="s2">&quot;pascal_voc_detection_val&quot;</span><span class="p">:</span> <span class="n">pascal_voc_detection_val</span><span class="p">,</span>
  596. <span class="p">}</span>
  597. <div class="viewcode-block" id="get"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.dataloaders.get">[docs]</a><span class="k">def</span> <span class="nf">get</span><span class="p">(</span><span class="n">name</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataset_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">dataset</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">DataLoader</span><span class="p">:</span>
  598. <span class="sd">&quot;&quot;&quot;</span>
  599. <span class="sd"> Get DataLoader of the recipe-configured dataset defined by name in ALL_DATALOADERS.</span>
  600. <span class="sd"> :param name: dataset name in ALL_DATALOADERS.</span>
  601. <span class="sd"> :param dataset_params: dataset params that override the yaml configured defaults, then passed to the dataset_cls.__init__.</span>
  602. <span class="sd"> :param dataloader_params: DataLoader params that override the yaml configured defaults, then passed to the DataLoader.__init__</span>
  603. <span class="sd"> :param dataset: torch.utils.data.Dataset to be used instead of passing &quot;name&quot; (i.e for external dataset objects).</span>
  604. <span class="sd"> :return: initialized DataLoader.</span>
  605. <span class="sd"> &quot;&quot;&quot;</span>
  606. <span class="k">if</span> <span class="n">dataset</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  607. <span class="k">if</span> <span class="n">name</span> <span class="ow">or</span> <span class="n">dataset_params</span><span class="p">:</span>
  608. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;&#39;name&#39; and &#39;dataset_params&#39; cannot be passed with initialized dataset.&quot;</span><span class="p">)</span>
  609. <span class="n">dataloader_params</span> <span class="o">=</span> <span class="n">_process_sampler_params</span><span class="p">(</span><span class="n">dataloader_params</span><span class="p">,</span> <span class="n">dataset</span><span class="p">,</span> <span class="p">{})</span>
  610. <span class="n">dataloader</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">dataset</span><span class="o">=</span><span class="n">dataset</span><span class="p">,</span> <span class="o">**</span><span class="n">dataloader_params</span><span class="p">)</span>
  611. <span class="k">elif</span> <span class="n">name</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">ALL_DATALOADERS</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
  612. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Unsupported dataloader: &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">name</span><span class="p">))</span>
  613. <span class="k">else</span><span class="p">:</span>
  614. <span class="n">dataloader_cls</span> <span class="o">=</span> <span class="n">ALL_DATALOADERS</span><span class="p">[</span><span class="n">name</span><span class="p">]</span>
  615. <span class="n">dataloader</span> <span class="o">=</span> <span class="n">dataloader_cls</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="n">dataloader_params</span><span class="p">)</span>
  616. <span class="k">return</span> <span class="n">dataloader</span></div>
  617. </pre></div>
  618. </div>
  619. </div>
  620. <footer>
  621. <hr/>
  622. <div role="contentinfo">
  623. <p>&#169; Copyright 2021, SuperGradients team.</p>
  624. </div>
  625. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  626. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  627. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  628. </footer>
  629. </div>
  630. </div>
  631. </section>
  632. </div>
  633. <script>
  634. jQuery(function () {
  635. SphinxRtdTheme.Navigation.enable(true);
  636. });
  637. </script>
  638. </body>
  639. </html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.training.datasets.all_datasets &mdash; SuperGradients 1.0 documentation</title>
  7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
  10. <!--[if lt IE 9]>
  11. <script src="../../../../_static/js/html5shiv.min.js"></script>
  12. <![endif]-->
  13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
  14. <script src="../../../../_static/jquery.js"></script>
  15. <script src="../../../../_static/underscore.js"></script>
  16. <script src="../../../../_static/doctools.js"></script>
  17. <script src="../../../../_static/js/theme.js"></script>
  18. <link rel="index" title="Index" href="../../../../genindex.html" />
  19. <link rel="search" title="Search" href="../../../../search.html" />
  20. </head>
  21. <body class="wy-body-for-nav">
  22. <div class="wy-grid-for-nav">
  23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  24. <div class="wy-side-scroll">
  25. <div class="wy-side-nav-search" >
  26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
  27. </a>
  28. <div role="search">
  29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
  30. <input type="text" name="q" placeholder="Search docs" />
  31. <input type="hidden" name="check_keywords" value="yes" />
  32. <input type="hidden" name="area" value="default" />
  33. </form>
  34. </div>
  35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
  37. <ul>
  38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
  39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
  40. </ul>
  41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
  42. <ul>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
  45. </ul>
  46. <p class="caption"><span class="caption-text">User Guide</span></p>
  47. <ul>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
  57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
  58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
  59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
  60. </ul>
  61. </div>
  62. </div>
  63. </nav>
  64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  66. <a href="../../../../index.html">SuperGradients</a>
  67. </nav>
  68. <div class="wy-nav-content">
  69. <div class="rst-content">
  70. <div role="navigation" aria-label="Page navigation">
  71. <ul class="wy-breadcrumbs">
  72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
  74. <li>super_gradients.training.datasets.all_datasets</li>
  75. <li class="wy-breadcrumbs-aside">
  76. </li>
  77. </ul>
  78. <hr/>
  79. </div>
  80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  81. <div itemprop="articleBody">
  82. <h1>Source code for super_gradients.training.datasets.all_datasets</h1><div class="highlight"><pre>
  83. <span></span><span class="kn">from</span> <span class="nn">collections</span> <span class="kn">import</span> <span class="n">defaultdict</span>
  84. <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Type</span>
  85. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.dataset_interfaces</span> <span class="kn">import</span> <span class="n">DatasetInterface</span><span class="p">,</span> <span class="n">TestDatasetInterface</span><span class="p">,</span> \
  86. <span class="n">LibraryDatasetInterface</span><span class="p">,</span> \
  87. <span class="n">ClassificationDatasetInterface</span><span class="p">,</span> <span class="n">Cifar10DatasetInterface</span><span class="p">,</span> <span class="n">Cifar100DatasetInterface</span><span class="p">,</span> \
  88. <span class="n">ImageNetDatasetInterface</span><span class="p">,</span> <span class="n">TinyImageNetDatasetInterface</span><span class="p">,</span> <span class="n">CoCoSegmentationDatasetInterface</span><span class="p">,</span>\
  89. <span class="n">PascalAUG2012SegmentationDataSetInterface</span><span class="p">,</span> <span class="n">PascalVOC2012SegmentationDataSetInterface</span>
  90. <span class="kn">from</span> <span class="nn">super_gradients.common.data_types.enum.deep_learning_task</span> <span class="kn">import</span> <span class="n">DeepLearningTask</span>
  91. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.dataset_interfaces.dataset_interface</span> <span class="kn">import</span> <span class="n">CoCoDetectionDatasetInterface</span>
  92. <span class="n">CLASSIFICATION_DATASETS</span> <span class="o">=</span> <span class="p">{</span>
  93. <span class="s2">&quot;test_dataset&quot;</span><span class="p">:</span> <span class="n">TestDatasetInterface</span><span class="p">,</span>
  94. <span class="s2">&quot;library_dataset&quot;</span><span class="p">:</span> <span class="n">LibraryDatasetInterface</span><span class="p">,</span>
  95. <span class="s2">&quot;classification_dataset&quot;</span><span class="p">:</span> <span class="n">ClassificationDatasetInterface</span><span class="p">,</span>
  96. <span class="s2">&quot;cifar_10&quot;</span><span class="p">:</span> <span class="n">Cifar10DatasetInterface</span><span class="p">,</span>
  97. <span class="s2">&quot;cifar_100&quot;</span><span class="p">:</span> <span class="n">Cifar100DatasetInterface</span><span class="p">,</span>
  98. <span class="s2">&quot;imagenet&quot;</span><span class="p">:</span> <span class="n">ImageNetDatasetInterface</span><span class="p">,</span>
  99. <span class="s2">&quot;tiny_imagenet&quot;</span><span class="p">:</span> <span class="n">TinyImageNetDatasetInterface</span>
  100. <span class="p">}</span>
  101. <span class="n">OBJECT_DETECTION_DATASETS</span> <span class="o">=</span> <span class="p">{</span>
  102. <span class="s2">&quot;coco&quot;</span><span class="p">:</span> <span class="n">CoCoDetectionDatasetInterface</span><span class="p">,</span>
  103. <span class="p">}</span>
  104. <span class="n">SEMANTIC_SEGMENTATION_DATASETS</span> <span class="o">=</span> <span class="p">{</span>
  105. <span class="s2">&quot;coco&quot;</span><span class="p">:</span> <span class="n">CoCoSegmentationDatasetInterface</span><span class="p">,</span>
  106. <span class="s2">&quot;pascal_voc&quot;</span><span class="p">:</span> <span class="n">PascalVOC2012SegmentationDataSetInterface</span><span class="p">,</span>
  107. <span class="s2">&quot;pascal_aug&quot;</span><span class="p">:</span> <span class="n">PascalAUG2012SegmentationDataSetInterface</span>
  108. <span class="p">}</span>
  109. <div class="viewcode-block" id="DataSetDoesNotExistException"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.all_datasets.DataSetDoesNotExistException">[docs]</a><span class="k">class</span> <span class="nc">DataSetDoesNotExistException</span><span class="p">(</span><span class="ne">Exception</span><span class="p">):</span>
  110. <span class="sd">&quot;&quot;&quot;</span>
  111. <span class="sd"> The requested dataset does not exist, or is not implemented.</span>
  112. <span class="sd"> &quot;&quot;&quot;</span>
  113. <span class="k">pass</span></div>
  114. <div class="viewcode-block" id="SgLibraryDatasets"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.all_datasets.SgLibraryDatasets">[docs]</a><span class="k">class</span> <span class="nc">SgLibraryDatasets</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
  115. <span class="sd">&quot;&quot;&quot;</span>
  116. <span class="sd"> Holds all of the different library dataset dictionaries, by DL Task mapping</span>
  117. <span class="sd"> Attributes:</span>
  118. <span class="sd"> CLASSIFICATION Dictionary of Classification Data sets</span>
  119. <span class="sd"> OBJECT_DETECTION Dictionary of Object Detection Data sets</span>
  120. <span class="sd"> SEMANTIC_SEGMENTATION Dictionary of Semantic Segmentation Data sets</span>
  121. <span class="sd"> &quot;&quot;&quot;</span>
  122. <span class="n">CLASSIFICATION</span> <span class="o">=</span> <span class="n">CLASSIFICATION_DATASETS</span>
  123. <span class="n">OBJECT_DETECTION</span> <span class="o">=</span> <span class="n">OBJECT_DETECTION_DATASETS</span>
  124. <span class="n">SEMANTIC_SEGMENTATION</span> <span class="o">=</span> <span class="n">SEMANTIC_SEGMENTATION_DATASETS</span>
  125. <span class="n">_datasets_mapping</span> <span class="o">=</span> <span class="p">{</span>
  126. <span class="n">DeepLearningTask</span><span class="o">.</span><span class="n">CLASSIFICATION</span><span class="p">:</span> <span class="n">CLASSIFICATION</span><span class="p">,</span>
  127. <span class="n">DeepLearningTask</span><span class="o">.</span><span class="n">SEMANTIC_SEGMENTATION</span><span class="p">:</span> <span class="n">SEMANTIC_SEGMENTATION</span><span class="p">,</span>
  128. <span class="n">DeepLearningTask</span><span class="o">.</span><span class="n">OBJECT_DETECTION</span><span class="p">:</span> <span class="n">OBJECT_DETECTION</span><span class="p">,</span>
  129. <span class="p">}</span>
  130. <div class="viewcode-block" id="SgLibraryDatasets.get_all_available_datasets"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.all_datasets.SgLibraryDatasets.get_all_available_datasets">[docs]</a> <span class="nd">@staticmethod</span>
  131. <span class="k">def</span> <span class="nf">get_all_available_datasets</span><span class="p">()</span> <span class="o">-&gt;</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]:</span>
  132. <span class="sd">&quot;&quot;&quot;</span>
  133. <span class="sd"> Gets all the available datasets.</span>
  134. <span class="sd"> &quot;&quot;&quot;</span>
  135. <span class="n">all_datasets</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="n">defaultdict</span><span class="p">(</span><span class="nb">list</span><span class="p">)</span>
  136. <span class="k">for</span> <span class="n">dl_task</span><span class="p">,</span> <span class="n">task_datasets</span> <span class="ow">in</span> <span class="n">SgLibraryDatasets</span><span class="o">.</span><span class="n">_datasets_mapping</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
  137. <span class="k">for</span> <span class="n">dataset_name</span><span class="p">,</span> <span class="n">dataset_interface</span> <span class="ow">in</span> <span class="n">task_datasets</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
  138. <span class="n">all_datasets</span><span class="p">[</span><span class="n">dl_task</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dataset_name</span><span class="p">)</span>
  139. <span class="c1"># TODO: Return Dataset Metadata list from the dataset interfaces objects</span>
  140. <span class="c1"># TODO: Transform DatasetInterface -&gt; DataSetMetadata</span>
  141. <span class="k">return</span> <span class="n">all_datasets</span></div>
  142. <div class="viewcode-block" id="SgLibraryDatasets.get_dataset"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.all_datasets.SgLibraryDatasets.get_dataset">[docs]</a> <span class="nd">@staticmethod</span>
  143. <span class="k">def</span> <span class="nf">get_dataset</span><span class="p">(</span><span class="n">dl_task</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">dataset_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Type</span><span class="p">[</span><span class="n">DatasetInterface</span><span class="p">]:</span>
  144. <span class="sd">&quot;&quot;&quot;</span>
  145. <span class="sd"> Get&#39;s a dataset with a given name for a given deep learning task.</span>
  146. <span class="sd"> examp:</span>
  147. <span class="sd"> &gt;&gt;&gt; SgLibraryDatasets.get_dataset(dl_task=&#39;classification&#39;, dataset_name=&#39;cifar_100&#39;)</span>
  148. <span class="sd"> &gt;&gt;&gt; &lt;Cifar100DatasetInterface instance&gt;</span>
  149. <span class="sd"> &quot;&quot;&quot;</span>
  150. <span class="n">task_datasets</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">DatasetInterface</span><span class="p">]</span> <span class="o">=</span> <span class="n">SgLibraryDatasets</span><span class="o">.</span><span class="n">_datasets_mapping</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">dl_task</span><span class="p">)</span>
  151. <span class="k">if</span> <span class="ow">not</span> <span class="n">task_datasets</span><span class="p">:</span>
  152. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Invalid Deep Learining Task: </span><span class="si">{</span><span class="n">dl_task</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
  153. <span class="n">dataset</span><span class="p">:</span> <span class="n">DatasetInterface</span> <span class="o">=</span> <span class="n">task_datasets</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">dataset_name</span><span class="p">)</span>
  154. <span class="k">if</span> <span class="ow">not</span> <span class="n">dataset</span><span class="p">:</span>
  155. <span class="k">raise</span> <span class="n">DataSetDoesNotExistException</span><span class="p">(</span><span class="n">dataset_name</span><span class="p">)</span>
  156. <span class="k">return</span> <span class="n">dataset</span></div></div>
  157. </pre></div>
  158. </div>
  159. </div>
  160. <footer>
  161. <hr/>
  162. <div role="contentinfo">
  163. <p>&#169; Copyright 2021, SuperGradients team.</p>
  164. </div>
  165. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  166. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  167. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  168. </footer>
  169. </div>
  170. </div>
  171. </section>
  172. </div>
  173. <script>
  174. jQuery(function () {
  175. SphinxRtdTheme.Navigation.enable(true);
  176. });
  177. </script>
  178. </body>
  179. </html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.training.datasets.auto_augment &mdash; SuperGradients 1.0 documentation</title>
  7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
  10. <!--[if lt IE 9]>
  11. <script src="../../../../_static/js/html5shiv.min.js"></script>
  12. <![endif]-->
  13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
  14. <script src="../../../../_static/jquery.js"></script>
  15. <script src="../../../../_static/underscore.js"></script>
  16. <script src="../../../../_static/doctools.js"></script>
  17. <script src="../../../../_static/js/theme.js"></script>
  18. <link rel="index" title="Index" href="../../../../genindex.html" />
  19. <link rel="search" title="Search" href="../../../../search.html" />
  20. </head>
  21. <body class="wy-body-for-nav">
  22. <div class="wy-grid-for-nav">
  23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  24. <div class="wy-side-scroll">
  25. <div class="wy-side-nav-search" >
  26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
  27. </a>
  28. <div role="search">
  29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
  30. <input type="text" name="q" placeholder="Search docs" />
  31. <input type="hidden" name="check_keywords" value="yes" />
  32. <input type="hidden" name="area" value="default" />
  33. </form>
  34. </div>
  35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
  37. <ul>
  38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
  39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
  40. </ul>
  41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
  42. <ul>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
  45. </ul>
  46. <p class="caption"><span class="caption-text">User Guide</span></p>
  47. <ul>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
  57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
  58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
  59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
  60. </ul>
  61. </div>
  62. </div>
  63. </nav>
  64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  66. <a href="../../../../index.html">SuperGradients</a>
  67. </nav>
  68. <div class="wy-nav-content">
  69. <div class="rst-content">
  70. <div role="navigation" aria-label="Page navigation">
  71. <ul class="wy-breadcrumbs">
  72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
  74. <li>super_gradients.training.datasets.auto_augment</li>
  75. <li class="wy-breadcrumbs-aside">
  76. </li>
  77. </ul>
  78. <hr/>
  79. </div>
  80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  81. <div itemprop="articleBody">
  82. <h1>Source code for super_gradients.training.datasets.auto_augment</h1><div class="highlight"><pre>
  83. <span></span><span class="sd">&quot;&quot;&quot; RandAugment</span>
  84. <span class="sd">RandAugment is a variant of AutoAugment which randomly selects transformations</span>
  85. <span class="sd"> from AutoAugment to be applied on an image.</span>
  86. <span class="sd">RandomAugmentation Implementation adapted from:</span>
  87. <span class="sd"> https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py</span>
  88. <span class="sd">Papers:</span>
  89. <span class="sd"> RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719</span>
  90. <span class="sd">&quot;&quot;&quot;</span>
  91. <span class="kn">import</span> <span class="nn">random</span>
  92. <span class="kn">import</span> <span class="nn">re</span>
  93. <span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span><span class="p">,</span> <span class="n">ImageOps</span><span class="p">,</span> <span class="n">ImageEnhance</span>
  94. <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
  95. <span class="n">_FILL</span> <span class="o">=</span> <span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="mi">128</span><span class="p">)</span>
  96. <span class="c1"># to unify the calls of the different augmentations in terms of params, all augmentations are set to work with a single</span>
  97. <span class="c1"># magnitude params, normalized according to _MAX_MAGNITUDE</span>
  98. <span class="n">_MAX_MAGNITUDE</span> <span class="o">=</span> <span class="mf">10.</span>
  99. <span class="n">_HPARAMS_DEFAULT</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(</span>
  100. <span class="n">translate_const</span><span class="o">=</span><span class="mi">250</span><span class="p">,</span>
  101. <span class="n">img_mean</span><span class="o">=</span><span class="n">_FILL</span><span class="p">,</span>
  102. <span class="p">)</span>
  103. <span class="c1"># Define the interpolation types</span>
  104. <span class="n">_RANDOM_INTERPOLATION</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">BILINEAR</span>
  105. <span class="k">def</span> <span class="nf">_interpolation</span><span class="p">(</span><span class="n">kwargs</span><span class="p">):</span>
  106. <span class="sd">&quot;&quot;&quot;</span>
  107. <span class="sd"> Performs Bi-Linear interpolation</span>
  108. <span class="sd"> &quot;&quot;&quot;</span>
  109. <span class="n">interpolation</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s1">&#39;resample&#39;</span><span class="p">,</span> <span class="n">Image</span><span class="o">.</span><span class="n">BILINEAR</span><span class="p">)</span>
  110. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">interpolation</span><span class="p">,</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)):</span>
  111. <span class="k">return</span> <span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="n">interpolation</span><span class="p">)</span>
  112. <span class="k">else</span><span class="p">:</span>
  113. <span class="k">return</span> <span class="n">interpolation</span>
  114. <div class="viewcode-block" id="shear_x"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.shear_x">[docs]</a><span class="k">def</span> <span class="nf">shear_x</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">factor</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  115. <span class="k">return</span> <span class="n">img</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">size</span><span class="p">,</span> <span class="n">Image</span><span class="o">.</span><span class="n">AFFINE</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">factor</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>
  116. <div class="viewcode-block" id="shear_y"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.shear_y">[docs]</a><span class="k">def</span> <span class="nf">shear_y</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">factor</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  117. <span class="k">return</span> <span class="n">img</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">size</span><span class="p">,</span> <span class="n">Image</span><span class="o">.</span><span class="n">AFFINE</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">factor</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>
  118. <div class="viewcode-block" id="translate_x_rel"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.translate_x_rel">[docs]</a><span class="k">def</span> <span class="nf">translate_x_rel</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">pct</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  119. <span class="n">pixels</span> <span class="o">=</span> <span class="n">pct</span> <span class="o">*</span> <span class="n">img</span><span class="o">.</span><span class="n">size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
  120. <span class="k">return</span> <span class="n">img</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">size</span><span class="p">,</span> <span class="n">Image</span><span class="o">.</span><span class="n">AFFINE</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">pixels</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>
  121. <div class="viewcode-block" id="translate_y_rel"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.translate_y_rel">[docs]</a><span class="k">def</span> <span class="nf">translate_y_rel</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">pct</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  122. <span class="n">pixels</span> <span class="o">=</span> <span class="n">pct</span> <span class="o">*</span> <span class="n">img</span><span class="o">.</span><span class="n">size</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
  123. <span class="k">return</span> <span class="n">img</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">size</span><span class="p">,</span> <span class="n">Image</span><span class="o">.</span><span class="n">AFFINE</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">pixels</span><span class="p">),</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>
  124. <div class="viewcode-block" id="translate_x_abs"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.translate_x_abs">[docs]</a><span class="k">def</span> <span class="nf">translate_x_abs</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">pixels</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  125. <span class="k">return</span> <span class="n">img</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">size</span><span class="p">,</span> <span class="n">Image</span><span class="o">.</span><span class="n">AFFINE</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">pixels</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>
  126. <div class="viewcode-block" id="translate_y_abs"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.translate_y_abs">[docs]</a><span class="k">def</span> <span class="nf">translate_y_abs</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">pixels</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  127. <span class="k">return</span> <span class="n">img</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">size</span><span class="p">,</span> <span class="n">Image</span><span class="o">.</span><span class="n">AFFINE</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">pixels</span><span class="p">),</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>
  128. <div class="viewcode-block" id="rotate"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.rotate">[docs]</a><span class="k">def</span> <span class="nf">rotate</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">degrees</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  129. <span class="k">return</span> <span class="n">img</span><span class="o">.</span><span class="n">rotate</span><span class="p">(</span><span class="n">degrees</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>
  130. <div class="viewcode-block" id="auto_contrast"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.auto_contrast">[docs]</a><span class="k">def</span> <span class="nf">auto_contrast</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="o">**</span><span class="n">__</span><span class="p">):</span>
  131. <span class="k">return</span> <span class="n">ImageOps</span><span class="o">.</span><span class="n">autocontrast</span><span class="p">(</span><span class="n">img</span><span class="p">)</span></div>
  132. <div class="viewcode-block" id="invert"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.invert">[docs]</a><span class="k">def</span> <span class="nf">invert</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="o">**</span><span class="n">__</span><span class="p">):</span>
  133. <span class="k">return</span> <span class="n">ImageOps</span><span class="o">.</span><span class="n">invert</span><span class="p">(</span><span class="n">img</span><span class="p">)</span></div>
  134. <div class="viewcode-block" id="equalize"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.equalize">[docs]</a><span class="k">def</span> <span class="nf">equalize</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="o">**</span><span class="n">__</span><span class="p">):</span>
  135. <span class="k">return</span> <span class="n">ImageOps</span><span class="o">.</span><span class="n">equalize</span><span class="p">(</span><span class="n">img</span><span class="p">)</span></div>
  136. <div class="viewcode-block" id="solarize"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.solarize">[docs]</a><span class="k">def</span> <span class="nf">solarize</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">thresh</span><span class="p">,</span> <span class="o">**</span><span class="n">__</span><span class="p">):</span>
  137. <span class="k">return</span> <span class="n">ImageOps</span><span class="o">.</span><span class="n">solarize</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">thresh</span><span class="p">)</span></div>
  138. <div class="viewcode-block" id="solarize_add"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.solarize_add">[docs]</a><span class="k">def</span> <span class="nf">solarize_add</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">add</span><span class="p">,</span> <span class="n">thresh</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span> <span class="o">**</span><span class="n">__</span><span class="p">):</span>
  139. <span class="n">lut</span> <span class="o">=</span> <span class="p">[]</span>
  140. <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">256</span><span class="p">):</span>
  141. <span class="k">if</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">thresh</span><span class="p">:</span>
  142. <span class="n">lut</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="nb">min</span><span class="p">(</span><span class="mi">255</span><span class="p">,</span> <span class="n">i</span> <span class="o">+</span> <span class="n">add</span><span class="p">))</span>
  143. <span class="k">else</span><span class="p">:</span>
  144. <span class="n">lut</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">i</span><span class="p">)</span>
  145. <span class="k">if</span> <span class="n">img</span><span class="o">.</span><span class="n">mode</span> <span class="ow">in</span> <span class="p">(</span><span class="s2">&quot;L&quot;</span><span class="p">,</span> <span class="s2">&quot;RGB&quot;</span><span class="p">):</span>
  146. <span class="k">if</span> <span class="n">img</span><span class="o">.</span><span class="n">mode</span> <span class="o">==</span> <span class="s2">&quot;RGB&quot;</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">lut</span><span class="p">)</span> <span class="o">==</span> <span class="mi">256</span><span class="p">:</span>
  147. <span class="n">lut</span> <span class="o">=</span> <span class="n">lut</span> <span class="o">+</span> <span class="n">lut</span> <span class="o">+</span> <span class="n">lut</span>
  148. <span class="k">return</span> <span class="n">img</span><span class="o">.</span><span class="n">point</span><span class="p">(</span><span class="n">lut</span><span class="p">)</span>
  149. <span class="k">else</span><span class="p">:</span>
  150. <span class="k">return</span> <span class="n">img</span></div>
  151. <div class="viewcode-block" id="posterize"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.posterize">[docs]</a><span class="k">def</span> <span class="nf">posterize</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">bits_to_keep</span><span class="p">,</span> <span class="o">**</span><span class="n">__</span><span class="p">):</span>
  152. <span class="k">if</span> <span class="n">bits_to_keep</span> <span class="o">&gt;=</span> <span class="mi">8</span><span class="p">:</span>
  153. <span class="k">return</span> <span class="n">img</span>
  154. <span class="k">return</span> <span class="n">ImageOps</span><span class="o">.</span><span class="n">posterize</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">bits_to_keep</span><span class="p">)</span></div>
  155. <div class="viewcode-block" id="contrast"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.contrast">[docs]</a><span class="k">def</span> <span class="nf">contrast</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">factor</span><span class="p">,</span> <span class="o">**</span><span class="n">__</span><span class="p">):</span>
  156. <span class="k">return</span> <span class="n">ImageEnhance</span><span class="o">.</span><span class="n">Contrast</span><span class="p">(</span><span class="n">img</span><span class="p">)</span><span class="o">.</span><span class="n">enhance</span><span class="p">(</span><span class="n">factor</span><span class="p">)</span></div>
  157. <div class="viewcode-block" id="color"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.color">[docs]</a><span class="k">def</span> <span class="nf">color</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">factor</span><span class="p">,</span> <span class="o">**</span><span class="n">__</span><span class="p">):</span>
  158. <span class="k">return</span> <span class="n">ImageEnhance</span><span class="o">.</span><span class="n">Color</span><span class="p">(</span><span class="n">img</span><span class="p">)</span><span class="o">.</span><span class="n">enhance</span><span class="p">(</span><span class="n">factor</span><span class="p">)</span></div>
  159. <div class="viewcode-block" id="brightness"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.brightness">[docs]</a><span class="k">def</span> <span class="nf">brightness</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">factor</span><span class="p">,</span> <span class="o">**</span><span class="n">__</span><span class="p">):</span>
  160. <span class="k">return</span> <span class="n">ImageEnhance</span><span class="o">.</span><span class="n">Brightness</span><span class="p">(</span><span class="n">img</span><span class="p">)</span><span class="o">.</span><span class="n">enhance</span><span class="p">(</span><span class="n">factor</span><span class="p">)</span></div>
  161. <div class="viewcode-block" id="sharpness"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.sharpness">[docs]</a><span class="k">def</span> <span class="nf">sharpness</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">factor</span><span class="p">,</span> <span class="o">**</span><span class="n">__</span><span class="p">):</span>
  162. <span class="k">return</span> <span class="n">ImageEnhance</span><span class="o">.</span><span class="n">Sharpness</span><span class="p">(</span><span class="n">img</span><span class="p">)</span><span class="o">.</span><span class="n">enhance</span><span class="p">(</span><span class="n">factor</span><span class="p">)</span></div>
  163. <span class="k">def</span> <span class="nf">_randomly_negate</span><span class="p">(</span><span class="n">v</span><span class="p">):</span>
  164. <span class="sd">&quot;&quot;&quot;With 50% prob, negate the value&quot;&quot;&quot;</span>
  165. <span class="k">return</span> <span class="o">-</span><span class="n">v</span> <span class="k">if</span> <span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">()</span> <span class="o">&gt;</span> <span class="mf">0.5</span> <span class="k">else</span> <span class="n">v</span>
  166. <span class="k">def</span> <span class="nf">_rotate_level_to_arg</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="n">_hparams</span><span class="p">):</span>
  167. <span class="c1"># range [-30, 30]</span>
  168. <span class="n">level</span> <span class="o">=</span> <span class="p">(</span><span class="n">level</span> <span class="o">/</span> <span class="n">_MAX_MAGNITUDE</span><span class="p">)</span> <span class="o">*</span> <span class="mf">30.</span>
  169. <span class="n">level</span> <span class="o">=</span> <span class="n">_randomly_negate</span><span class="p">(</span><span class="n">level</span><span class="p">)</span>
  170. <span class="k">return</span> <span class="n">level</span><span class="p">,</span>
  171. <span class="k">def</span> <span class="nf">_enhance_level_to_arg</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="n">_hparams</span><span class="p">):</span>
  172. <span class="c1"># range [0.1, 1.9]</span>
  173. <span class="k">return</span> <span class="p">(</span><span class="n">level</span> <span class="o">/</span> <span class="n">_MAX_MAGNITUDE</span><span class="p">)</span> <span class="o">*</span> <span class="mf">1.8</span> <span class="o">+</span> <span class="mf">0.1</span><span class="p">,</span>
  174. <span class="k">def</span> <span class="nf">_enhance_increasing_level_to_arg</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="n">_hparams</span><span class="p">):</span>
  175. <span class="c1"># range [0.1, 1.9]</span>
  176. <span class="n">level</span> <span class="o">=</span> <span class="p">(</span><span class="n">level</span> <span class="o">/</span> <span class="n">_MAX_MAGNITUDE</span><span class="p">)</span> <span class="o">*</span> <span class="mf">.9</span>
  177. <span class="n">level</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">+</span> <span class="n">_randomly_negate</span><span class="p">(</span><span class="n">level</span><span class="p">)</span>
  178. <span class="k">return</span> <span class="n">level</span><span class="p">,</span>
  179. <span class="k">def</span> <span class="nf">_shear_level_to_arg</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="n">_hparams</span><span class="p">):</span>
  180. <span class="c1"># range [-0.3, 0.3]</span>
  181. <span class="n">level</span> <span class="o">=</span> <span class="p">(</span><span class="n">level</span> <span class="o">/</span> <span class="n">_MAX_MAGNITUDE</span><span class="p">)</span> <span class="o">*</span> <span class="mf">0.3</span>
  182. <span class="n">level</span> <span class="o">=</span> <span class="n">_randomly_negate</span><span class="p">(</span><span class="n">level</span><span class="p">)</span>
  183. <span class="k">return</span> <span class="n">level</span><span class="p">,</span>
  184. <span class="k">def</span> <span class="nf">_translate_abs_level_to_arg</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="n">hparams</span><span class="p">):</span>
  185. <span class="n">translate_const</span> <span class="o">=</span> <span class="n">hparams</span><span class="p">[</span><span class="s1">&#39;translate_const&#39;</span><span class="p">]</span>
  186. <span class="n">level</span> <span class="o">=</span> <span class="p">(</span><span class="n">level</span> <span class="o">/</span> <span class="n">_MAX_MAGNITUDE</span><span class="p">)</span> <span class="o">*</span> <span class="nb">float</span><span class="p">(</span><span class="n">translate_const</span><span class="p">)</span>
  187. <span class="n">level</span> <span class="o">=</span> <span class="n">_randomly_negate</span><span class="p">(</span><span class="n">level</span><span class="p">)</span>
  188. <span class="k">return</span> <span class="n">level</span><span class="p">,</span>
  189. <span class="k">def</span> <span class="nf">_translate_rel_level_to_arg</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="n">hparams</span><span class="p">):</span>
  190. <span class="c1"># default range [-0.45, 0.45]</span>
  191. <span class="n">translate_pct</span> <span class="o">=</span> <span class="n">hparams</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;translate_pct&#39;</span><span class="p">,</span> <span class="mf">0.45</span><span class="p">)</span>
  192. <span class="n">level</span> <span class="o">=</span> <span class="p">(</span><span class="n">level</span> <span class="o">/</span> <span class="n">_MAX_MAGNITUDE</span><span class="p">)</span> <span class="o">*</span> <span class="n">translate_pct</span>
  193. <span class="n">level</span> <span class="o">=</span> <span class="n">_randomly_negate</span><span class="p">(</span><span class="n">level</span><span class="p">)</span>
  194. <span class="k">return</span> <span class="n">level</span><span class="p">,</span>
  195. <span class="k">def</span> <span class="nf">_posterize_level_to_arg</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="n">_hparams</span><span class="p">):</span>
  196. <span class="c1"># As per Tensorflow TPU EfficientNet impl</span>
  197. <span class="c1"># range [0, 4], &#39;keep 0 up to 4 MSB of original image&#39;</span>
  198. <span class="c1"># intensity/severity of augmentation decreases with level</span>
  199. <span class="k">return</span> <span class="nb">int</span><span class="p">((</span><span class="n">level</span> <span class="o">/</span> <span class="n">_MAX_MAGNITUDE</span><span class="p">)</span> <span class="o">*</span> <span class="mi">4</span><span class="p">),</span>
  200. <span class="k">def</span> <span class="nf">_posterize_increasing_level_to_arg</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="n">hparams</span><span class="p">):</span>
  201. <span class="c1"># As per Tensorflow models research and UDA impl</span>
  202. <span class="c1"># range [4, 0], &#39;keep 4 down to 0 MSB of original image&#39;,</span>
  203. <span class="c1"># intensity/severity of augmentation increases with level</span>
  204. <span class="k">return</span> <span class="mi">4</span> <span class="o">-</span> <span class="n">_posterize_level_to_arg</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="n">hparams</span><span class="p">)[</span><span class="mi">0</span><span class="p">],</span>
  205. <span class="k">def</span> <span class="nf">_posterize_original_level_to_arg</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="n">_hparams</span><span class="p">):</span>
  206. <span class="c1"># As per original AutoAugment paper description</span>
  207. <span class="c1"># range [4, 8], &#39;keep 4 up to 8 MSB of image&#39;</span>
  208. <span class="c1"># intensity/severity of augmentation decreases with level</span>
  209. <span class="k">return</span> <span class="nb">int</span><span class="p">((</span><span class="n">level</span> <span class="o">/</span> <span class="n">_MAX_MAGNITUDE</span><span class="p">)</span> <span class="o">*</span> <span class="mi">4</span><span class="p">)</span> <span class="o">+</span> <span class="mi">4</span><span class="p">,</span>
  210. <span class="k">def</span> <span class="nf">_solarize_level_to_arg</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="n">_hparams</span><span class="p">):</span>
  211. <span class="c1"># range [0, 256]</span>
  212. <span class="c1"># intensity/severity of augmentation decreases with level</span>
  213. <span class="k">return</span> <span class="nb">int</span><span class="p">((</span><span class="n">level</span> <span class="o">/</span> <span class="n">_MAX_MAGNITUDE</span><span class="p">)</span> <span class="o">*</span> <span class="mi">256</span><span class="p">),</span>
  214. <span class="k">def</span> <span class="nf">_solarize_increasing_level_to_arg</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="n">_hparams</span><span class="p">):</span>
  215. <span class="c1"># range [0, 256]</span>
  216. <span class="c1"># intensity/severity of augmentation increases with level</span>
  217. <span class="k">return</span> <span class="mi">256</span> <span class="o">-</span> <span class="n">_solarize_level_to_arg</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="n">_hparams</span><span class="p">)[</span><span class="mi">0</span><span class="p">],</span>
  218. <span class="k">def</span> <span class="nf">_solarize_add_level_to_arg</span><span class="p">(</span><span class="n">level</span><span class="p">,</span> <span class="n">_hparams</span><span class="p">):</span>
  219. <span class="c1"># range [0, 110]</span>
  220. <span class="k">return</span> <span class="nb">int</span><span class="p">((</span><span class="n">level</span> <span class="o">/</span> <span class="n">_MAX_MAGNITUDE</span><span class="p">)</span> <span class="o">*</span> <span class="mi">110</span><span class="p">),</span>
  221. <span class="n">LEVEL_TO_ARG</span> <span class="o">=</span> <span class="p">{</span>
  222. <span class="s1">&#39;AutoContrast&#39;</span><span class="p">:</span> <span class="kc">None</span><span class="p">,</span>
  223. <span class="s1">&#39;Equalize&#39;</span><span class="p">:</span> <span class="kc">None</span><span class="p">,</span>
  224. <span class="s1">&#39;Invert&#39;</span><span class="p">:</span> <span class="kc">None</span><span class="p">,</span>
  225. <span class="s1">&#39;Rotate&#39;</span><span class="p">:</span> <span class="n">_rotate_level_to_arg</span><span class="p">,</span>
  226. <span class="c1"># There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers</span>
  227. <span class="s1">&#39;Posterize&#39;</span><span class="p">:</span> <span class="n">_posterize_level_to_arg</span><span class="p">,</span>
  228. <span class="s1">&#39;PosterizeIncreasing&#39;</span><span class="p">:</span> <span class="n">_posterize_increasing_level_to_arg</span><span class="p">,</span>
  229. <span class="s1">&#39;PosterizeOriginal&#39;</span><span class="p">:</span> <span class="n">_posterize_original_level_to_arg</span><span class="p">,</span>
  230. <span class="s1">&#39;Solarize&#39;</span><span class="p">:</span> <span class="n">_solarize_level_to_arg</span><span class="p">,</span>
  231. <span class="s1">&#39;SolarizeIncreasing&#39;</span><span class="p">:</span> <span class="n">_solarize_increasing_level_to_arg</span><span class="p">,</span>
  232. <span class="s1">&#39;SolarizeAdd&#39;</span><span class="p">:</span> <span class="n">_solarize_add_level_to_arg</span><span class="p">,</span>
  233. <span class="s1">&#39;Color&#39;</span><span class="p">:</span> <span class="n">_enhance_level_to_arg</span><span class="p">,</span>
  234. <span class="s1">&#39;ColorIncreasing&#39;</span><span class="p">:</span> <span class="n">_enhance_increasing_level_to_arg</span><span class="p">,</span>
  235. <span class="s1">&#39;Contrast&#39;</span><span class="p">:</span> <span class="n">_enhance_level_to_arg</span><span class="p">,</span>
  236. <span class="s1">&#39;ContrastIncreasing&#39;</span><span class="p">:</span> <span class="n">_enhance_increasing_level_to_arg</span><span class="p">,</span>
  237. <span class="s1">&#39;Brightness&#39;</span><span class="p">:</span> <span class="n">_enhance_level_to_arg</span><span class="p">,</span>
  238. <span class="s1">&#39;BrightnessIncreasing&#39;</span><span class="p">:</span> <span class="n">_enhance_increasing_level_to_arg</span><span class="p">,</span>
  239. <span class="s1">&#39;Sharpness&#39;</span><span class="p">:</span> <span class="n">_enhance_level_to_arg</span><span class="p">,</span>
  240. <span class="s1">&#39;SharpnessIncreasing&#39;</span><span class="p">:</span> <span class="n">_enhance_increasing_level_to_arg</span><span class="p">,</span>
  241. <span class="s1">&#39;ShearX&#39;</span><span class="p">:</span> <span class="n">_shear_level_to_arg</span><span class="p">,</span>
  242. <span class="s1">&#39;ShearY&#39;</span><span class="p">:</span> <span class="n">_shear_level_to_arg</span><span class="p">,</span>
  243. <span class="s1">&#39;TranslateX&#39;</span><span class="p">:</span> <span class="n">_translate_abs_level_to_arg</span><span class="p">,</span>
  244. <span class="s1">&#39;TranslateY&#39;</span><span class="p">:</span> <span class="n">_translate_abs_level_to_arg</span><span class="p">,</span>
  245. <span class="s1">&#39;TranslateXRel&#39;</span><span class="p">:</span> <span class="n">_translate_rel_level_to_arg</span><span class="p">,</span>
  246. <span class="s1">&#39;TranslateYRel&#39;</span><span class="p">:</span> <span class="n">_translate_rel_level_to_arg</span><span class="p">,</span>
  247. <span class="p">}</span>
  248. <span class="n">NAME_TO_OP</span> <span class="o">=</span> <span class="p">{</span>
  249. <span class="s1">&#39;AutoContrast&#39;</span><span class="p">:</span> <span class="n">auto_contrast</span><span class="p">,</span>
  250. <span class="s1">&#39;Equalize&#39;</span><span class="p">:</span> <span class="n">equalize</span><span class="p">,</span>
  251. <span class="s1">&#39;Invert&#39;</span><span class="p">:</span> <span class="n">invert</span><span class="p">,</span>
  252. <span class="s1">&#39;Rotate&#39;</span><span class="p">:</span> <span class="n">rotate</span><span class="p">,</span>
  253. <span class="s1">&#39;Posterize&#39;</span><span class="p">:</span> <span class="n">posterize</span><span class="p">,</span>
  254. <span class="s1">&#39;PosterizeIncreasing&#39;</span><span class="p">:</span> <span class="n">posterize</span><span class="p">,</span>
  255. <span class="s1">&#39;PosterizeOriginal&#39;</span><span class="p">:</span> <span class="n">posterize</span><span class="p">,</span>
  256. <span class="s1">&#39;Solarize&#39;</span><span class="p">:</span> <span class="n">solarize</span><span class="p">,</span>
  257. <span class="s1">&#39;SolarizeIncreasing&#39;</span><span class="p">:</span> <span class="n">solarize</span><span class="p">,</span>
  258. <span class="s1">&#39;SolarizeAdd&#39;</span><span class="p">:</span> <span class="n">solarize_add</span><span class="p">,</span>
  259. <span class="s1">&#39;Color&#39;</span><span class="p">:</span> <span class="n">color</span><span class="p">,</span>
  260. <span class="s1">&#39;ColorIncreasing&#39;</span><span class="p">:</span> <span class="n">color</span><span class="p">,</span>
  261. <span class="s1">&#39;Contrast&#39;</span><span class="p">:</span> <span class="n">contrast</span><span class="p">,</span>
  262. <span class="s1">&#39;ContrastIncreasing&#39;</span><span class="p">:</span> <span class="n">contrast</span><span class="p">,</span>
  263. <span class="s1">&#39;Brightness&#39;</span><span class="p">:</span> <span class="n">brightness</span><span class="p">,</span>
  264. <span class="s1">&#39;BrightnessIncreasing&#39;</span><span class="p">:</span> <span class="n">brightness</span><span class="p">,</span>
  265. <span class="s1">&#39;Sharpness&#39;</span><span class="p">:</span> <span class="n">sharpness</span><span class="p">,</span>
  266. <span class="s1">&#39;SharpnessIncreasing&#39;</span><span class="p">:</span> <span class="n">sharpness</span><span class="p">,</span>
  267. <span class="s1">&#39;ShearX&#39;</span><span class="p">:</span> <span class="n">shear_x</span><span class="p">,</span>
  268. <span class="s1">&#39;ShearY&#39;</span><span class="p">:</span> <span class="n">shear_y</span><span class="p">,</span>
  269. <span class="s1">&#39;TranslateX&#39;</span><span class="p">:</span> <span class="n">translate_x_abs</span><span class="p">,</span>
  270. <span class="s1">&#39;TranslateY&#39;</span><span class="p">:</span> <span class="n">translate_y_abs</span><span class="p">,</span>
  271. <span class="s1">&#39;TranslateXRel&#39;</span><span class="p">:</span> <span class="n">translate_x_rel</span><span class="p">,</span>
  272. <span class="s1">&#39;TranslateYRel&#39;</span><span class="p">:</span> <span class="n">translate_y_rel</span><span class="p">,</span>
  273. <span class="p">}</span>
  274. <div class="viewcode-block" id="AugmentOp"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.AugmentOp">[docs]</a><span class="k">class</span> <span class="nc">AugmentOp</span><span class="p">:</span>
  275. <span class="sd">&quot;&quot;&quot;</span>
  276. <span class="sd"> single auto augment operations</span>
  277. <span class="sd"> &quot;&quot;&quot;</span>
  278. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">prob</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">magnitude</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">hparams</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  279. <span class="n">hparams</span> <span class="o">=</span> <span class="n">hparams</span> <span class="ow">or</span> <span class="n">_HPARAMS_DEFAULT</span>
  280. <span class="bp">self</span><span class="o">.</span><span class="n">aug_fn</span> <span class="o">=</span> <span class="n">NAME_TO_OP</span><span class="p">[</span><span class="n">name</span><span class="p">]</span>
  281. <span class="bp">self</span><span class="o">.</span><span class="n">level_fn</span> <span class="o">=</span> <span class="n">LEVEL_TO_ARG</span><span class="p">[</span><span class="n">name</span><span class="p">]</span>
  282. <span class="bp">self</span><span class="o">.</span><span class="n">prob</span> <span class="o">=</span> <span class="n">prob</span>
  283. <span class="bp">self</span><span class="o">.</span><span class="n">magnitude</span> <span class="o">=</span> <span class="n">magnitude</span>
  284. <span class="bp">self</span><span class="o">.</span><span class="n">hparams</span> <span class="o">=</span> <span class="n">hparams</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
  285. <span class="bp">self</span><span class="o">.</span><span class="n">kwargs</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(</span>
  286. <span class="n">fillcolor</span><span class="o">=</span><span class="n">hparams</span><span class="p">[</span><span class="s1">&#39;img_mean&#39;</span><span class="p">]</span> <span class="k">if</span> <span class="s1">&#39;img_mean&#39;</span> <span class="ow">in</span> <span class="n">hparams</span> <span class="k">else</span> <span class="n">_FILL</span><span class="p">,</span>
  287. <span class="n">resample</span><span class="o">=</span><span class="n">hparams</span><span class="p">[</span><span class="s1">&#39;interpolation&#39;</span><span class="p">]</span> <span class="k">if</span> <span class="s1">&#39;interpolation&#39;</span> <span class="ow">in</span> <span class="n">hparams</span> <span class="k">else</span> <span class="n">_RANDOM_INTERPOLATION</span><span class="p">,</span>
  288. <span class="p">)</span>
  289. <span class="c1"># If magnitude_std is &gt; 0, introduce some randomness</span>
  290. <span class="bp">self</span><span class="o">.</span><span class="n">magnitude_std</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">hparams</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s1">&#39;magnitude_std&#39;</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
  291. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">img</span><span class="p">):</span>
  292. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">prob</span> <span class="o">&lt;</span> <span class="mf">1.0</span> <span class="ow">and</span> <span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">()</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">prob</span><span class="p">:</span>
  293. <span class="k">return</span> <span class="n">img</span>
  294. <span class="n">magnitude</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">magnitude</span>
  295. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">magnitude_std</span><span class="p">:</span>
  296. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">magnitude_std</span> <span class="o">==</span> <span class="nb">float</span><span class="p">(</span><span class="s1">&#39;inf&#39;</span><span class="p">):</span>
  297. <span class="n">magnitude</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">magnitude</span><span class="p">)</span>
  298. <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">magnitude_std</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
  299. <span class="n">magnitude</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">gauss</span><span class="p">(</span><span class="n">magnitude</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">magnitude_std</span><span class="p">)</span>
  300. <span class="n">magnitude</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">_MAX_MAGNITUDE</span><span class="p">,</span> <span class="nb">max</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">magnitude</span><span class="p">))</span> <span class="c1"># clip to valid range</span>
  301. <span class="n">level_args</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">level_fn</span><span class="p">(</span><span class="n">magnitude</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">hparams</span><span class="p">)</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">level_fn</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="nb">tuple</span><span class="p">()</span>
  302. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">aug_fn</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="o">*</span><span class="n">level_args</span><span class="p">,</span> <span class="o">**</span><span class="bp">self</span><span class="o">.</span><span class="n">kwargs</span><span class="p">)</span></div>
  303. <span class="n">_RAND_TRANSFORMS</span> <span class="o">=</span> <span class="p">[</span>
  304. <span class="s1">&#39;AutoContrast&#39;</span><span class="p">,</span>
  305. <span class="s1">&#39;Equalize&#39;</span><span class="p">,</span>
  306. <span class="s1">&#39;Invert&#39;</span><span class="p">,</span>
  307. <span class="s1">&#39;Rotate&#39;</span><span class="p">,</span>
  308. <span class="s1">&#39;Posterize&#39;</span><span class="p">,</span>
  309. <span class="s1">&#39;Solarize&#39;</span><span class="p">,</span>
  310. <span class="s1">&#39;SolarizeAdd&#39;</span><span class="p">,</span>
  311. <span class="s1">&#39;Color&#39;</span><span class="p">,</span>
  312. <span class="s1">&#39;Contrast&#39;</span><span class="p">,</span>
  313. <span class="s1">&#39;Brightness&#39;</span><span class="p">,</span>
  314. <span class="s1">&#39;Sharpness&#39;</span><span class="p">,</span>
  315. <span class="s1">&#39;ShearX&#39;</span><span class="p">,</span>
  316. <span class="s1">&#39;ShearY&#39;</span><span class="p">,</span>
  317. <span class="s1">&#39;TranslateXRel&#39;</span><span class="p">,</span>
  318. <span class="s1">&#39;TranslateYRel&#39;</span><span class="p">,</span>
  319. <span class="c1"># &#39;Cutout&#39; # NOTE I&#39;ve implement this as random erasing separately</span>
  320. <span class="p">]</span>
  321. <span class="n">_RAND_INCREASING_TRANSFORMS</span> <span class="o">=</span> <span class="p">[</span>
  322. <span class="s1">&#39;AutoContrast&#39;</span><span class="p">,</span>
  323. <span class="s1">&#39;Equalize&#39;</span><span class="p">,</span>
  324. <span class="s1">&#39;Invert&#39;</span><span class="p">,</span>
  325. <span class="s1">&#39;Rotate&#39;</span><span class="p">,</span>
  326. <span class="s1">&#39;PosterizeIncreasing&#39;</span><span class="p">,</span>
  327. <span class="s1">&#39;SolarizeIncreasing&#39;</span><span class="p">,</span>
  328. <span class="s1">&#39;SolarizeAdd&#39;</span><span class="p">,</span>
  329. <span class="s1">&#39;ColorIncreasing&#39;</span><span class="p">,</span>
  330. <span class="s1">&#39;ContrastIncreasing&#39;</span><span class="p">,</span>
  331. <span class="s1">&#39;BrightnessIncreasing&#39;</span><span class="p">,</span>
  332. <span class="s1">&#39;SharpnessIncreasing&#39;</span><span class="p">,</span>
  333. <span class="s1">&#39;ShearX&#39;</span><span class="p">,</span>
  334. <span class="s1">&#39;ShearY&#39;</span><span class="p">,</span>
  335. <span class="s1">&#39;TranslateXRel&#39;</span><span class="p">,</span>
  336. <span class="s1">&#39;TranslateYRel&#39;</span><span class="p">,</span>
  337. <span class="c1"># &#39;Cutout&#39; # NOTE I&#39;ve implement this as random erasing separately</span>
  338. <span class="p">]</span>
  339. <span class="c1"># These experimental weights are based loosely on the relative improvements mentioned in paper.</span>
  340. <span class="c1"># They may not result in increased performance, but could likely be tuned to so.</span>
  341. <span class="n">_RAND_CHOICE_WEIGHTS_0</span> <span class="o">=</span> <span class="p">{</span>
  342. <span class="s1">&#39;Rotate&#39;</span><span class="p">:</span> <span class="mf">0.3</span><span class="p">,</span>
  343. <span class="s1">&#39;ShearX&#39;</span><span class="p">:</span> <span class="mf">0.2</span><span class="p">,</span>
  344. <span class="s1">&#39;ShearY&#39;</span><span class="p">:</span> <span class="mf">0.2</span><span class="p">,</span>
  345. <span class="s1">&#39;TranslateXRel&#39;</span><span class="p">:</span> <span class="mf">0.1</span><span class="p">,</span>
  346. <span class="s1">&#39;TranslateYRel&#39;</span><span class="p">:</span> <span class="mf">0.1</span><span class="p">,</span>
  347. <span class="s1">&#39;Color&#39;</span><span class="p">:</span> <span class="mf">.025</span><span class="p">,</span>
  348. <span class="s1">&#39;Sharpness&#39;</span><span class="p">:</span> <span class="mf">0.025</span><span class="p">,</span>
  349. <span class="s1">&#39;AutoContrast&#39;</span><span class="p">:</span> <span class="mf">0.025</span><span class="p">,</span>
  350. <span class="s1">&#39;Solarize&#39;</span><span class="p">:</span> <span class="mf">.005</span><span class="p">,</span>
  351. <span class="s1">&#39;SolarizeAdd&#39;</span><span class="p">:</span> <span class="mf">.005</span><span class="p">,</span>
  352. <span class="s1">&#39;Contrast&#39;</span><span class="p">:</span> <span class="mf">.005</span><span class="p">,</span>
  353. <span class="s1">&#39;Brightness&#39;</span><span class="p">:</span> <span class="mf">.005</span><span class="p">,</span>
  354. <span class="s1">&#39;Equalize&#39;</span><span class="p">:</span> <span class="mf">.005</span><span class="p">,</span>
  355. <span class="s1">&#39;Posterize&#39;</span><span class="p">:</span> <span class="mi">0</span><span class="p">,</span>
  356. <span class="s1">&#39;Invert&#39;</span><span class="p">:</span> <span class="mi">0</span><span class="p">,</span>
  357. <span class="p">}</span>
  358. <span class="k">def</span> <span class="nf">_select_rand_weights</span><span class="p">(</span><span class="n">weight_idx</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">transforms</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  359. <span class="n">transforms</span> <span class="o">=</span> <span class="n">transforms</span> <span class="ow">or</span> <span class="n">_RAND_TRANSFORMS</span>
  360. <span class="k">assert</span> <span class="n">weight_idx</span> <span class="o">==</span> <span class="mi">0</span> <span class="c1"># only one set of weights currently</span>
  361. <span class="n">rand_weights</span> <span class="o">=</span> <span class="n">_RAND_CHOICE_WEIGHTS_0</span>
  362. <span class="n">probs</span> <span class="o">=</span> <span class="p">[</span><span class="n">rand_weights</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">transforms</span><span class="p">]</span>
  363. <span class="n">probs</span> <span class="o">/=</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">probs</span><span class="p">)</span>
  364. <span class="k">return</span> <span class="n">probs</span>
  365. <div class="viewcode-block" id="rand_augment_ops"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.rand_augment_ops">[docs]</a><span class="k">def</span> <span class="nf">rand_augment_ops</span><span class="p">(</span><span class="n">magnitude</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">hparams</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">transforms</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  366. <span class="n">hparams</span> <span class="o">=</span> <span class="n">hparams</span> <span class="ow">or</span> <span class="n">_HPARAMS_DEFAULT</span>
  367. <span class="n">transforms</span> <span class="o">=</span> <span class="n">transforms</span> <span class="ow">or</span> <span class="n">_RAND_TRANSFORMS</span>
  368. <span class="k">return</span> <span class="p">[</span><span class="n">AugmentOp</span><span class="p">(</span>
  369. <span class="n">name</span><span class="p">,</span> <span class="n">prob</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">magnitude</span><span class="o">=</span><span class="n">magnitude</span><span class="p">,</span> <span class="n">hparams</span><span class="o">=</span><span class="n">hparams</span><span class="p">)</span> <span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">transforms</span><span class="p">]</span></div>
  370. <div class="viewcode-block" id="RandAugment"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.RandAugment">[docs]</a><span class="k">class</span> <span class="nc">RandAugment</span><span class="p">:</span>
  371. <span class="sd">&quot;&quot;&quot;</span>
  372. <span class="sd"> Random auto augment class, will select auto augment transforms according to probability weights for each op</span>
  373. <span class="sd"> &quot;&quot;&quot;</span>
  374. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">ops</span><span class="p">,</span> <span class="n">num_layers</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">choice_weights</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  375. <span class="bp">self</span><span class="o">.</span><span class="n">ops</span> <span class="o">=</span> <span class="n">ops</span>
  376. <span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span> <span class="o">=</span> <span class="n">num_layers</span>
  377. <span class="bp">self</span><span class="o">.</span><span class="n">choice_weights</span> <span class="o">=</span> <span class="n">choice_weights</span>
  378. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">img</span><span class="p">):</span>
  379. <span class="c1"># no replacement when using weighted choice</span>
  380. <span class="n">ops</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span>
  381. <span class="bp">self</span><span class="o">.</span><span class="n">ops</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">,</span> <span class="n">replace</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">choice_weights</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">choice_weights</span><span class="p">)</span>
  382. <span class="k">for</span> <span class="n">op</span> <span class="ow">in</span> <span class="n">ops</span><span class="p">:</span>
  383. <span class="n">img</span> <span class="o">=</span> <span class="n">op</span><span class="p">(</span><span class="n">img</span><span class="p">)</span>
  384. <span class="k">return</span> <span class="n">img</span></div>
  385. <div class="viewcode-block" id="rand_augment_transform"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.auto_augment.rand_augment_transform">[docs]</a><span class="k">def</span> <span class="nf">rand_augment_transform</span><span class="p">(</span><span class="n">config_str</span><span class="p">,</span> <span class="n">hparams</span><span class="p">):</span>
  386. <span class="sd">&quot;&quot;&quot;</span>
  387. <span class="sd"> Create a RandAugment transform</span>
  388. <span class="sd"> :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by</span>
  389. <span class="sd"> dashes (&#39;-&#39;). The first section defines the specific variant of rand augment (currently only &#39;rand&#39;). The remaining</span>
  390. <span class="sd"> sections, not order sepecific determine</span>
  391. <span class="sd"> &#39;m&#39; - integer magnitude of rand augment</span>
  392. <span class="sd"> &#39;n&#39; - integer num layers (number of transform ops selected per image)</span>
  393. <span class="sd"> &#39;w&#39; - integer probabiliy weight index (index of a set of weights to influence choice of op)</span>
  394. <span class="sd"> &#39;mstd&#39; - float std deviation of magnitude noise applied</span>
  395. <span class="sd"> &#39;inc&#39; - integer (bool), use augmentations that increase in severity with magnitude (default: 0)</span>
  396. <span class="sd"> Ex &#39;rand-m9-n3-mstd0.5&#39; results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5</span>
  397. <span class="sd"> &#39;rand-mstd1-w0&#39; results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2</span>
  398. <span class="sd"> :param hparams: Other hparams (kwargs) for the RandAugmentation scheme</span>
  399. <span class="sd"> :return: A PyTorch compatible Transform</span>
  400. <span class="sd"> &quot;&quot;&quot;</span>
  401. <span class="n">magnitude</span> <span class="o">=</span> <span class="n">_MAX_MAGNITUDE</span> <span class="c1"># default to _MAX_MAGNITUDE for magnitude (currently 10)</span>
  402. <span class="n">num_layers</span> <span class="o">=</span> <span class="mi">2</span> <span class="c1"># default to 2 ops per image</span>
  403. <span class="n">weight_idx</span> <span class="o">=</span> <span class="kc">None</span> <span class="c1"># default to no probability weights for op choice</span>
  404. <span class="n">transforms</span> <span class="o">=</span> <span class="n">_RAND_TRANSFORMS</span>
  405. <span class="n">config</span> <span class="o">=</span> <span class="n">config_str</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s1">&#39;-&#39;</span><span class="p">)</span>
  406. <span class="k">for</span> <span class="n">c</span> <span class="ow">in</span> <span class="n">config</span><span class="p">:</span>
  407. <span class="n">cs</span> <span class="o">=</span> <span class="n">re</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="sa">r</span><span class="s1">&#39;(\d.*)&#39;</span><span class="p">,</span> <span class="n">c</span><span class="p">)</span>
  408. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">cs</span><span class="p">)</span> <span class="o">&lt;</span> <span class="mi">2</span><span class="p">:</span>
  409. <span class="k">continue</span>
  410. <span class="n">key</span><span class="p">,</span> <span class="n">val</span> <span class="o">=</span> <span class="n">cs</span><span class="p">[:</span><span class="mi">2</span><span class="p">]</span>
  411. <span class="k">if</span> <span class="n">key</span> <span class="o">==</span> <span class="s1">&#39;mstd&#39;</span><span class="p">:</span>
  412. <span class="c1"># noise param injected via hparams for now</span>
  413. <span class="n">hparams</span><span class="o">.</span><span class="n">setdefault</span><span class="p">(</span><span class="s1">&#39;magnitude_std&#39;</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="n">val</span><span class="p">))</span>
  414. <span class="k">elif</span> <span class="n">key</span> <span class="o">==</span> <span class="s1">&#39;inc&#39;</span><span class="p">:</span>
  415. <span class="k">if</span> <span class="nb">bool</span><span class="p">(</span><span class="n">val</span><span class="p">):</span>
  416. <span class="n">transforms</span> <span class="o">=</span> <span class="n">_RAND_INCREASING_TRANSFORMS</span>
  417. <span class="k">elif</span> <span class="n">key</span> <span class="o">==</span> <span class="s1">&#39;m&#39;</span><span class="p">:</span>
  418. <span class="n">magnitude</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">val</span><span class="p">)</span>
  419. <span class="k">elif</span> <span class="n">key</span> <span class="o">==</span> <span class="s1">&#39;n&#39;</span><span class="p">:</span>
  420. <span class="n">num_layers</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">val</span><span class="p">)</span>
  421. <span class="k">elif</span> <span class="n">key</span> <span class="o">==</span> <span class="s1">&#39;w&#39;</span><span class="p">:</span>
  422. <span class="n">weight_idx</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">val</span><span class="p">)</span>
  423. <span class="k">else</span><span class="p">:</span>
  424. <span class="k">assert</span> <span class="kc">False</span><span class="p">,</span> <span class="s1">&#39;Unknown RandAugment config section&#39;</span>
  425. <span class="n">ra_ops</span> <span class="o">=</span> <span class="n">rand_augment_ops</span><span class="p">(</span><span class="n">magnitude</span><span class="o">=</span><span class="n">magnitude</span><span class="p">,</span> <span class="n">hparams</span><span class="o">=</span><span class="n">hparams</span><span class="p">,</span> <span class="n">transforms</span><span class="o">=</span><span class="n">transforms</span><span class="p">)</span>
  426. <span class="n">choice_weights</span> <span class="o">=</span> <span class="kc">None</span> <span class="k">if</span> <span class="n">weight_idx</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">_select_rand_weights</span><span class="p">(</span><span class="n">weight_idx</span><span class="p">)</span>
  427. <span class="k">return</span> <span class="n">RandAugment</span><span class="p">(</span><span class="n">ra_ops</span><span class="p">,</span> <span class="n">num_layers</span><span class="p">,</span> <span class="n">choice_weights</span><span class="o">=</span><span class="n">choice_weights</span><span class="p">)</span></div>
  428. </pre></div>
  429. </div>
  430. </div>
  431. <footer>
  432. <hr/>
  433. <div role="contentinfo">
  434. <p>&#169; Copyright 2021, SuperGradients team.</p>
  435. </div>
  436. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  437. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  438. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  439. </footer>
  440. </div>
  441. </div>
  442. </section>
  443. </div>
  444. <script>
  445. jQuery(function () {
  446. SphinxRtdTheme.Navigation.enable(true);
  447. });
  448. </script>
  449. </body>
  450. </html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.training.datasets.classification_datasets.cifar &mdash; SuperGradients 3.0.3 documentation</title>
  7. <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
  10. <link rel="stylesheet" href="../../../../../_static/custom.css" type="text/css" />
  11. <!--[if lt IE 9]>
  12. <script src="../../../../../_static/js/html5shiv.min.js"></script>
  13. <![endif]-->
  14. <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
  15. <script src="../../../../../_static/jquery.js"></script>
  16. <script src="../../../../../_static/underscore.js"></script>
  17. <script src="../../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
  18. <script src="../../../../../_static/doctools.js"></script>
  19. <script src="../../../../../_static/sphinx_highlight.js"></script>
  20. <script src="../../../../../_static/js/theme.js"></script>
  21. <link rel="index" title="Index" href="../../../../../genindex.html" />
  22. <link rel="search" title="Search" href="../../../../../search.html" />
  23. </head>
  24. <body class="wy-body-for-nav">
  25. <div class="wy-grid-for-nav">
  26. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  27. <div class="wy-side-scroll">
  28. <div class="wy-side-nav-search" >
  29. <a href="../../../../../index.html" class="icon icon-home"> SuperGradients
  30. </a>
  31. <div role="search">
  32. <form id="rtd-search-form" class="wy-form" action="../../../../../search.html" method="get">
  33. <input type="text" name="q" placeholder="Search docs" />
  34. <input type="hidden" name="check_keywords" value="yes" />
  35. <input type="hidden" name="area" value="default" />
  36. </form>
  37. </div>
  38. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  39. <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
  40. <ul>
  41. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
  42. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#quick-installation">Quick Installation</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#what-s-new">What’s New</a></li>
  45. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#coming-soon">Coming soon</a></li>
  46. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#table-of-content">Table of Content</a></li>
  47. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#getting-started">Getting Started</a></li>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#advanced-features">Advanced Features</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#installation-methods">Installation Methods</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#documentation">Documentation</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#contributing">Contributing</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#citation">Citation</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#community">Community</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#license">License</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#deci-platform">Deci Platform</a></li>
  57. </ul>
  58. <p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
  59. <ul>
  60. <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
  61. <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
  62. </ul>
  63. </div>
  64. </div>
  65. </nav>
  66. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  67. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  68. <a href="../../../../../index.html">SuperGradients</a>
  69. </nav>
  70. <div class="wy-nav-content">
  71. <div class="rst-content">
  72. <div role="navigation" aria-label="Page navigation">
  73. <ul class="wy-breadcrumbs">
  74. <li><a href="../../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  75. <li><a href="../../../../index.html">Module code</a> &raquo;</li>
  76. <li>super_gradients.training.datasets.classification_datasets.cifar</li>
  77. <li class="wy-breadcrumbs-aside">
  78. </li>
  79. </ul>
  80. <hr/>
  81. </div>
  82. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  83. <div itemprop="articleBody">
  84. <h1>Source code for super_gradients.training.datasets.classification_datasets.cifar</h1><div class="highlight"><pre>
  85. <span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">Union</span>
  86. <span class="kn">from</span> <span class="nn">torchvision.transforms</span> <span class="kn">import</span> <span class="n">Compose</span>
  87. <span class="kn">from</span> <span class="nn">super_gradients.common.factories.transforms_factory</span> <span class="kn">import</span> <span class="n">TransformsFactory</span>
  88. <span class="kn">from</span> <span class="nn">super_gradients.common.decorators.factory_decorator</span> <span class="kn">import</span> <span class="n">resolve_param</span>
  89. <span class="kn">from</span> <span class="nn">torchvision.datasets</span> <span class="kn">import</span> <span class="n">CIFAR10</span><span class="p">,</span> <span class="n">CIFAR100</span>
  90. <div class="viewcode-block" id="Cifar10"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.Cifar10">[docs]</a><span class="k">class</span> <span class="nc">Cifar10</span><span class="p">(</span><span class="n">CIFAR10</span><span class="p">):</span>
  91. <span class="sd">&quot;&quot;&quot;</span>
  92. <span class="sd"> CIFAR10 Dataset</span>
  93. <span class="sd"> :param root: Path for the data to be extracted</span>
  94. <span class="sd"> :param train: Bool to load training (True) or validation (False) part of the dataset</span>
  95. <span class="sd"> :param transforms: List of transforms to apply sequentially on sample. Wrapped internally with torchvision.Compose</span>
  96. <span class="sd"> :param target_transform: Transform to apply to target output</span>
  97. <span class="sd"> :param download: Download (True) the dataset from source</span>
  98. <span class="sd"> &quot;&quot;&quot;</span>
  99. <span class="nd">@resolve_param</span><span class="p">(</span><span class="s2">&quot;transforms&quot;</span><span class="p">,</span> <span class="n">TransformsFactory</span><span class="p">())</span>
  100. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
  101. <span class="bp">self</span><span class="p">,</span>
  102. <span class="n">root</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
  103. <span class="n">train</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
  104. <span class="n">transforms</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">list</span><span class="p">,</span> <span class="nb">dict</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
  105. <span class="n">target_transform</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Callable</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
  106. <span class="n">download</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
  107. <span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
  108. <span class="c1"># TO KEEP BACKWARD COMPATABILITY, WILL BE REMOVED IN THE FUTURE ONCE WE ALLIGN TORCHVISION/NATIVE TRANSFORMS</span>
  109. <span class="c1"># TREATMENT IN FACTORIES (I.E STATING COMPOSE IN CONFIGS)</span>
  110. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">transforms</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
  111. <span class="n">transforms</span> <span class="o">=</span> <span class="n">Compose</span><span class="p">(</span><span class="n">transforms</span><span class="p">)</span>
  112. <span class="nb">super</span><span class="p">(</span><span class="n">Cifar10</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
  113. <span class="n">root</span><span class="o">=</span><span class="n">root</span><span class="p">,</span>
  114. <span class="n">train</span><span class="o">=</span><span class="n">train</span><span class="p">,</span>
  115. <span class="n">transform</span><span class="o">=</span><span class="n">transforms</span><span class="p">,</span>
  116. <span class="n">target_transform</span><span class="o">=</span><span class="n">target_transform</span><span class="p">,</span>
  117. <span class="n">download</span><span class="o">=</span><span class="n">download</span><span class="p">,</span>
  118. <span class="p">)</span></div>
  119. <div class="viewcode-block" id="Cifar100"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.Cifar100">[docs]</a><span class="k">class</span> <span class="nc">Cifar100</span><span class="p">(</span><span class="n">CIFAR100</span><span class="p">):</span>
  120. <span class="nd">@resolve_param</span><span class="p">(</span><span class="s2">&quot;transforms&quot;</span><span class="p">,</span> <span class="n">TransformsFactory</span><span class="p">())</span>
  121. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
  122. <span class="bp">self</span><span class="p">,</span>
  123. <span class="n">root</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
  124. <span class="n">train</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
  125. <span class="n">transforms</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">list</span><span class="p">,</span> <span class="nb">dict</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
  126. <span class="n">target_transform</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Callable</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
  127. <span class="n">download</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
  128. <span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
  129. <span class="sd">&quot;&quot;&quot;</span>
  130. <span class="sd"> CIFAR100 Dataset</span>
  131. <span class="sd"> :param root: Path for the data to be extracted</span>
  132. <span class="sd"> :param train: Bool to load training (True) or validation (False) part of the dataset</span>
  133. <span class="sd"> :param transforms: List of transforms to apply sequentially on sample. Wrapped internally with torchvision.Compose</span>
  134. <span class="sd"> :param target_transform: Transform to apply to target output</span>
  135. <span class="sd"> :param download: Download (True) the dataset from source</span>
  136. <span class="sd"> &quot;&quot;&quot;</span>
  137. <span class="c1"># TO KEEP BACKWARD COMPATABILITY, WILL BE REMOVED IN THE FUTURE ONCE WE ALLIGN TORCHVISION/NATIVE TRANSFORMS</span>
  138. <span class="c1"># TREATMENT IN FACTORIES (I.E STATING COMPOSE IN CONFIGS)</span>
  139. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">transforms</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
  140. <span class="n">transforms</span> <span class="o">=</span> <span class="n">Compose</span><span class="p">(</span><span class="n">transforms</span><span class="p">)</span>
  141. <span class="nb">super</span><span class="p">(</span><span class="n">Cifar100</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
  142. <span class="n">root</span><span class="o">=</span><span class="n">root</span><span class="p">,</span>
  143. <span class="n">train</span><span class="o">=</span><span class="n">train</span><span class="p">,</span>
  144. <span class="n">transform</span><span class="o">=</span><span class="n">transforms</span><span class="p">,</span>
  145. <span class="n">target_transform</span><span class="o">=</span><span class="n">target_transform</span><span class="p">,</span>
  146. <span class="n">download</span><span class="o">=</span><span class="n">download</span><span class="p">,</span>
  147. <span class="p">)</span></div>
  148. </pre></div>
  149. </div>
  150. </div>
  151. <footer>
  152. <hr/>
  153. <div role="contentinfo">
  154. <p>&#169; Copyright 2021, SuperGradients team.</p>
  155. </div>
  156. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  157. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  158. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  159. </footer>
  160. </div>
  161. </div>
  162. </section>
  163. </div>
  164. <script>
  165. jQuery(function () {
  166. SphinxRtdTheme.Navigation.enable(true);
  167. });
  168. </script>
  169. </body>
  170. </html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.training.datasets.classification_datasets.imagenet_dataset &mdash; SuperGradients 3.0.3 documentation</title>
  7. <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
  10. <link rel="stylesheet" href="../../../../../_static/custom.css" type="text/css" />
  11. <!--[if lt IE 9]>
  12. <script src="../../../../../_static/js/html5shiv.min.js"></script>
  13. <![endif]-->
  14. <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
  15. <script src="../../../../../_static/jquery.js"></script>
  16. <script src="../../../../../_static/underscore.js"></script>
  17. <script src="../../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
  18. <script src="../../../../../_static/doctools.js"></script>
  19. <script src="../../../../../_static/sphinx_highlight.js"></script>
  20. <script src="../../../../../_static/js/theme.js"></script>
  21. <link rel="index" title="Index" href="../../../../../genindex.html" />
  22. <link rel="search" title="Search" href="../../../../../search.html" />
  23. </head>
  24. <body class="wy-body-for-nav">
  25. <div class="wy-grid-for-nav">
  26. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  27. <div class="wy-side-scroll">
  28. <div class="wy-side-nav-search" >
  29. <a href="../../../../../index.html" class="icon icon-home"> SuperGradients
  30. </a>
  31. <div role="search">
  32. <form id="rtd-search-form" class="wy-form" action="../../../../../search.html" method="get">
  33. <input type="text" name="q" placeholder="Search docs" />
  34. <input type="hidden" name="check_keywords" value="yes" />
  35. <input type="hidden" name="area" value="default" />
  36. </form>
  37. </div>
  38. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  39. <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
  40. <ul>
  41. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
  42. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#quick-installation">Quick Installation</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#what-s-new">What’s New</a></li>
  45. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#coming-soon">Coming soon</a></li>
  46. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#table-of-content">Table of Content</a></li>
  47. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#getting-started">Getting Started</a></li>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#advanced-features">Advanced Features</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#installation-methods">Installation Methods</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#documentation">Documentation</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#contributing">Contributing</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#citation">Citation</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#community">Community</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#license">License</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#deci-platform">Deci Platform</a></li>
  57. </ul>
  58. <p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
  59. <ul>
  60. <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
  61. <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
  62. </ul>
  63. </div>
  64. </div>
  65. </nav>
  66. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  67. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  68. <a href="../../../../../index.html">SuperGradients</a>
  69. </nav>
  70. <div class="wy-nav-content">
  71. <div class="rst-content">
  72. <div role="navigation" aria-label="Page navigation">
  73. <ul class="wy-breadcrumbs">
  74. <li><a href="../../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  75. <li><a href="../../../../index.html">Module code</a> &raquo;</li>
  76. <li>super_gradients.training.datasets.classification_datasets.imagenet_dataset</li>
  77. <li class="wy-breadcrumbs-aside">
  78. </li>
  79. </ul>
  80. <hr/>
  81. </div>
  82. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  83. <div itemprop="articleBody">
  84. <h1>Source code for super_gradients.training.datasets.classification_datasets.imagenet_dataset</h1><div class="highlight"><pre>
  85. <span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Union</span>
  86. <span class="kn">import</span> <span class="nn">torchvision.datasets</span> <span class="k">as</span> <span class="nn">torch_datasets</span>
  87. <span class="kn">from</span> <span class="nn">torchvision.transforms</span> <span class="kn">import</span> <span class="n">Compose</span>
  88. <span class="kn">from</span> <span class="nn">super_gradients.common.decorators.factory_decorator</span> <span class="kn">import</span> <span class="n">resolve_param</span>
  89. <span class="kn">from</span> <span class="nn">super_gradients.common.factories.transforms_factory</span> <span class="kn">import</span> <span class="n">TransformsFactory</span>
  90. <div class="viewcode-block" id="ImageNetDataset"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.ImageNetDataset">[docs]</a><span class="k">class</span> <span class="nc">ImageNetDataset</span><span class="p">(</span><span class="n">torch_datasets</span><span class="o">.</span><span class="n">ImageFolder</span><span class="p">):</span>
  91. <span class="sd">&quot;&quot;&quot;ImageNetDataset dataset&quot;&quot;&quot;</span>
  92. <span class="nd">@resolve_param</span><span class="p">(</span><span class="s2">&quot;transforms&quot;</span><span class="p">,</span> <span class="n">factory</span><span class="o">=</span><span class="n">TransformsFactory</span><span class="p">())</span>
  93. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">root</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">transforms</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">list</span><span class="p">,</span> <span class="nb">dict</span><span class="p">]</span> <span class="o">=</span> <span class="p">[],</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  94. <span class="c1"># TO KEEP BACKWARD COMPATABILITY, WILL BE REMOVED IN THE FUTURE ONCE WE ALLIGN TORCHVISION/NATIVE TRANSFORMS</span>
  95. <span class="c1"># TREATMENT IN FACTORIES (I.E STATING COMPOSE IN CONFIGS)</span>
  96. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">transforms</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
  97. <span class="n">transforms</span> <span class="o">=</span> <span class="n">Compose</span><span class="p">(</span><span class="n">transforms</span><span class="p">)</span>
  98. <span class="nb">super</span><span class="p">(</span><span class="n">ImageNetDataset</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">root</span><span class="p">,</span> <span class="n">transform</span><span class="o">=</span><span class="n">transforms</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>
  99. </pre></div>
  100. </div>
  101. </div>
  102. <footer>
  103. <hr/>
  104. <div role="contentinfo">
  105. <p>&#169; Copyright 2021, SuperGradients team.</p>
  106. </div>
  107. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  108. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  109. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  110. </footer>
  111. </div>
  112. </div>
  113. </section>
  114. </div>
  115. <script>
  116. jQuery(function () {
  117. SphinxRtdTheme.Navigation.enable(true);
  118. });
  119. </script>
  120. </body>
  121. </html>
Discard
Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
@@ -3,10 +3,11 @@
 <head>
 <head>
   <meta charset="utf-8" />
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-  <title>super_gradients.training.datasets.data_augmentation &mdash; SuperGradients 1.0 documentation</title>
+  <title>super_gradients.training.datasets.data_augmentation &mdash; SuperGradients 3.0.3 documentation</title>
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
+      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
   <!--[if lt IE 9]>
   <!--[if lt IE 9]>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
   <![endif]-->
   <![endif]-->
@@ -14,7 +15,9 @@
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/underscore.js"></script>
         <script src="../../../../_static/underscore.js"></script>
+        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
         <script src="../../../../_static/doctools.js"></script>
         <script src="../../../../_static/doctools.js"></script>
+        <script src="../../../../_static/sphinx_highlight.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="search" title="Search" href="../../../../search.html" /> 
     <link rel="search" title="Search" href="../../../../search.html" /> 
@@ -35,30 +38,29 @@
   </form>
   </form>
 </div>
 </div>
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
-              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
+              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
 <ul>
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
 </ul>
 </ul>
-<p class="caption"><span class="caption-text">Technical Documentation</span></p>
+<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
 <ul>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
-</ul>
-<p class="caption"><span class="caption-text">User Guide</span></p>
-<ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
 </ul>
 </ul>
 
 
         </div>
         </div>
@@ -91,9 +93,9 @@
 <span class="kn">from</span> <span class="nn">torchvision.transforms</span> <span class="kn">import</span> <span class="n">RandomErasing</span>
 <span class="kn">from</span> <span class="nn">torchvision.transforms</span> <span class="kn">import</span> <span class="n">RandomErasing</span>
 
 
 
 
-<div class="viewcode-block" id="DataAugmentation"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.DataAugmentation">[docs]</a><span class="k">class</span> <span class="nc">DataAugmentation</span><span class="p">:</span>
+<div class="viewcode-block" id="DataAugmentation"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.DataAugmentation">[docs]</a><span class="k">class</span> <span class="nc">DataAugmentation</span><span class="p">:</span>
 
 
-<div class="viewcode-block" id="DataAugmentation.to_tensor"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.DataAugmentation.to_tensor">[docs]</a>    <span class="nd">@staticmethod</span>
+<div class="viewcode-block" id="DataAugmentation.to_tensor"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.DataAugmentation.to_tensor">[docs]</a>    <span class="nd">@staticmethod</span>
     <span class="k">def</span> <span class="nf">to_tensor</span><span class="p">():</span>
     <span class="k">def</span> <span class="nf">to_tensor</span><span class="p">():</span>
         <span class="k">def</span> <span class="nf">_to_tensor</span><span class="p">(</span><span class="n">image</span><span class="p">):</span>
         <span class="k">def</span> <span class="nf">_to_tensor</span><span class="p">(</span><span class="n">image</span><span class="p">):</span>
             <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">image</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
             <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">image</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
@@ -104,7 +106,7 @@
 
 
         <span class="k">return</span> <span class="n">_to_tensor</span></div>
         <span class="k">return</span> <span class="n">_to_tensor</span></div>
 
 
-<div class="viewcode-block" id="DataAugmentation.normalize"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.DataAugmentation.normalize">[docs]</a>    <span class="nd">@staticmethod</span>
+<div class="viewcode-block" id="DataAugmentation.normalize"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.DataAugmentation.normalize">[docs]</a>    <span class="nd">@staticmethod</span>
     <span class="k">def</span> <span class="nf">normalize</span><span class="p">(</span><span class="n">mean</span><span class="p">,</span> <span class="n">std</span><span class="p">):</span>
     <span class="k">def</span> <span class="nf">normalize</span><span class="p">(</span><span class="n">mean</span><span class="p">,</span> <span class="n">std</span><span class="p">):</span>
         <span class="n">mean</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">mean</span><span class="p">)</span>
         <span class="n">mean</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">mean</span><span class="p">)</span>
         <span class="n">std</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">std</span><span class="p">)</span>
         <span class="n">std</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">std</span><span class="p">)</span>
@@ -116,7 +118,7 @@
 
 
         <span class="k">return</span> <span class="n">_normalize</span></div>
         <span class="k">return</span> <span class="n">_normalize</span></div>
 
 
-<div class="viewcode-block" id="DataAugmentation.cutout"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.DataAugmentation.cutout">[docs]</a>    <span class="nd">@staticmethod</span>
+<div class="viewcode-block" id="DataAugmentation.cutout"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.DataAugmentation.cutout">[docs]</a>    <span class="nd">@staticmethod</span>
     <span class="k">def</span> <span class="nf">cutout</span><span class="p">(</span><span class="n">mask_size</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">cutout_inside</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">mask_color</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="m
     <span class="k">def</span> <span class="nf">cutout</span><span class="p">(</span><span class="n">mask_size</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">cutout_inside</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">mask_color</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="m
         <span class="n">mask_size_half</span> <span class="o">=</span> <span class="n">mask_size</span> <span class="o">//</span> <span class="mi">2</span>
         <span class="n">mask_size_half</span> <span class="o">=</span> <span class="n">mask_size</span> <span class="o">//</span> <span class="mi">2</span>
         <span class="n">offset</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">if</span> <span class="n">mask_size</span> <span class="o">%</span> <span class="mi">2</span> <span class="o">==</span> <span class="mi">0</span> <span class="k">else</span> <span class="mi">0</span>
         <span class="n">offset</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">if</span> <span class="n">mask_size</span> <span class="o">%</span> <span class="mi">2</span> <span class="o">==</span> <span class="mi">0</span> <span class="k">else</span> <span class="mi">0</span>
@@ -159,7 +161,7 @@
                             <span class="p">[</span><span class="o">-</span><span class="mf">0.5836</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.6948</span><span class="p">,</span> <span class="mf">0.4203</span><span class="p">]])}</span>
                             <span class="p">[</span><span class="o">-</span><span class="mf">0.5836</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.6948</span><span class="p">,</span> <span class="mf">0.4203</span><span class="p">]])}</span>
 
 
 
 
-<div class="viewcode-block" id="Lighting"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.Lighting">[docs]</a><span class="k">class</span> <span class="nc">Lighting</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
+<span class="k">class</span> <span class="nc">Lighting</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    Lighting noise(AlexNet - style PCA - based noise)</span>
 <span class="sd">    Lighting noise(AlexNet - style PCA - based noise)</span>
 <span class="sd">    Taken from fastai Imagenet training -</span>
 <span class="sd">    Taken from fastai Imagenet training -</span>
@@ -183,10 +185,10 @@
             <span class="o">.</span><span class="n">mul</span><span class="p">(</span><span class="n">alpha</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> \
             <span class="o">.</span><span class="n">mul</span><span class="p">(</span><span class="n">alpha</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> \
             <span class="o">.</span><span class="n">mul</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">eigval</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span 
             <span class="o">.</span><span class="n">mul</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">eigval</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span 
             <span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span>
             <span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span>
-        <span class="k">return</span> <span class="n">img</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">rgb</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">expand_as</span><span class="p">(</span><span class="n">img</span><span
+        <span class="k">return</span> <span class="n">img</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">rgb</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">expand_as</span><span class="p">(</span><span class="n">img</span><span
 
 
 
 
-<div class="viewcode-block" id="RandomErase"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.RandomErase">[docs]</a><span class="k">class</span> <span class="nc">RandomErase</span><span class="p">(</span><span class="n">RandomErasing</span><span class="p">):</span>
+<span class="k">class</span> <span class="nc">RandomErase</span><span class="p">(</span><span class="n">RandomErasing</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    A simple class that translates the parameters supported in SuperGradient&#39;s code base</span>
 <span class="sd">    A simple class that translates the parameters supported in SuperGradient&#39;s code base</span>
 <span class="sd">    &quot;&quot;&quot;</span>
 <span class="sd">    &quot;&quot;&quot;</span>
@@ -197,7 +199,7 @@
             <span class="n">value</span> <span class="o">=</span> <span class="nb">float</span><span class="p">(</span><span class="n">value</span><span class="p">)</span>
             <span class="n">value</span> <span class="o">=</span> <span class="nb">float</span><span class="p">(</span><span class="n">value</span><span class="p">)</span>
         <span class="k">except</span> <span class="ne">ValueError</span><span class="p">:</span>
         <span class="k">except</span> <span class="ne">ValueError</span><span class="p">:</span>
             <span class="k">pass</span>
             <span class="k">pass</span>
-        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">p</span><span class="o">=</span><span class="n">probability</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">p</span><span class="o">=</span><span class="n">probability</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="n">value</span><span class="p">)</span>
 </pre></div>
 </pre></div>
 
 
            </div>
            </div>
@@ -227,4 +229,4 @@
   </script> 
   </script> 
 
 
 </body>
 </body>
-</html>
+</html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.training.datasets.dataset_interfaces.dataset_interface &mdash; SuperGradients 1.0 documentation</title>
  7. <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
  10. <!--[if lt IE 9]>
  11. <script src="../../../../../_static/js/html5shiv.min.js"></script>
  12. <![endif]-->
  13. <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
  14. <script src="../../../../../_static/jquery.js"></script>
  15. <script src="../../../../../_static/underscore.js"></script>
  16. <script src="../../../../../_static/doctools.js"></script>
  17. <script src="../../../../../_static/js/theme.js"></script>
  18. <link rel="index" title="Index" href="../../../../../genindex.html" />
  19. <link rel="search" title="Search" href="../../../../../search.html" />
  20. </head>
  21. <body class="wy-body-for-nav">
  22. <div class="wy-grid-for-nav">
  23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  24. <div class="wy-side-scroll">
  25. <div class="wy-side-nav-search" >
  26. <a href="../../../../../index.html" class="icon icon-home"> SuperGradients
  27. </a>
  28. <div role="search">
  29. <form id="rtd-search-form" class="wy-form" action="../../../../../search.html" method="get">
  30. <input type="text" name="q" placeholder="Search docs" />
  31. <input type="hidden" name="check_keywords" value="yes" />
  32. <input type="hidden" name="area" value="default" />
  33. </form>
  34. </div>
  35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
  37. <ul>
  38. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
  39. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#supergradients">SuperGradients</a></li>
  40. </ul>
  41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
  42. <ul>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
  45. </ul>
  46. <p class="caption"><span class="caption-text">User Guide</span></p>
  47. <ul>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html">What is SuperGradients?</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#installation">Installation</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#training-parameters">Training Parameters</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#network-architectures">Network Architectures</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
  57. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
  58. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
  59. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
  60. </ul>
  61. </div>
  62. </div>
  63. </nav>
  64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  66. <a href="../../../../../index.html">SuperGradients</a>
  67. </nav>
  68. <div class="wy-nav-content">
  69. <div class="rst-content">
  70. <div role="navigation" aria-label="Page navigation">
  71. <ul class="wy-breadcrumbs">
  72. <li><a href="../../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  73. <li><a href="../../../../index.html">Module code</a> &raquo;</li>
  74. <li>super_gradients.training.datasets.dataset_interfaces.dataset_interface</li>
  75. <li class="wy-breadcrumbs-aside">
  76. </li>
  77. </ul>
  78. <hr/>
  79. </div>
  80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  81. <div itemprop="articleBody">
  82. <h1>Source code for super_gradients.training.datasets.dataset_interfaces.dataset_interface</h1><div class="highlight"><pre>
  83. <span></span><span class="kn">import</span> <span class="nn">os</span>
  84. <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
  85. <span class="kn">import</span> <span class="nn">torch</span>
  86. <span class="kn">import</span> <span class="nn">torchvision</span>
  87. <span class="kn">import</span> <span class="nn">torchvision.datasets</span> <span class="k">as</span> <span class="nn">datasets</span>
  88. <span class="kn">from</span> <span class="nn">torch.utils.data.distributed</span> <span class="kn">import</span> <span class="n">DistributedSampler</span>
  89. <span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">ConcatDataset</span><span class="p">,</span> <span class="n">BatchSampler</span><span class="p">,</span> <span class="n">DataLoader</span>
  90. <span class="kn">import</span> <span class="nn">torchvision.transforms</span> <span class="k">as</span> <span class="nn">transforms</span>
  91. <span class="kn">from</span> <span class="nn">super_gradients.common</span> <span class="kn">import</span> <span class="n">DatasetDataInterface</span>
  92. <span class="kn">from</span> <span class="nn">super_gradients.common.environment</span> <span class="kn">import</span> <span class="n">AWS_ENV_NAME</span>
  93. <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
  94. <span class="kn">from</span> <span class="nn">super_gradients.training</span> <span class="kn">import</span> <span class="n">utils</span> <span class="k">as</span> <span class="n">core_utils</span>
  95. <span class="kn">from</span> <span class="nn">super_gradients.training.utils.distributed_training_utils</span> <span class="kn">import</span> <span class="n">get_local_rank</span><span class="p">,</span> <span class="n">wait_for_the_master</span>
  96. <span class="kn">from</span> <span class="nn">super_gradients.training.utils</span> <span class="kn">import</span> <span class="n">get_param</span>
  97. <span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">DetectionTargetsFormat</span>
  98. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets</span> <span class="kn">import</span> <span class="n">datasets_utils</span><span class="p">,</span> <span class="n">DataAugmentation</span>
  99. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.datasets_conf</span> <span class="kn">import</span> <span class="n">COCO_DETECTION_CLASSES_LIST</span>
  100. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.data_augmentation</span> <span class="kn">import</span> <span class="n">Lighting</span><span class="p">,</span> <span class="n">RandomErase</span>
  101. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.mixup</span> <span class="kn">import</span> <span class="n">CollateMixup</span>
  102. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.detection_datasets</span> <span class="kn">import</span> <span class="n">COCODetectionDataset</span><span class="p">,</span> <span class="n">PascalVOCDetectionDataset</span>
  103. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.samplers.infinite_sampler</span> <span class="kn">import</span> <span class="n">InfiniteSampler</span>
  104. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.segmentation_datasets</span> <span class="kn">import</span> <span class="n">PascalVOC2012SegmentationDataSet</span><span class="p">,</span> \
  105. <span class="n">PascalAUG2012SegmentationDataSet</span><span class="p">,</span> <span class="n">CoCoSegmentationDataSet</span>
  106. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.segmentation_datasets.cityscape_segmentation</span> <span class="kn">import</span> <span class="n">CityscapesDataset</span>
  107. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.segmentation_datasets.supervisely_persons_segmentation</span> <span class="kn">import</span> \
  108. <span class="n">SuperviselyPersonsDataset</span>
  109. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.samplers.repeated_augmentation_sampler</span> <span class="kn">import</span> <span class="n">RepeatAugSampler</span>
  110. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.datasets_utils</span> <span class="kn">import</span> <span class="n">RandomResizedCropAndInterpolation</span><span class="p">,</span> <span class="n">worker_init_reset_seed</span>
  111. <span class="kn">from</span> <span class="nn">super_gradients.training.transforms.transforms</span> <span class="kn">import</span> <span class="n">DetectionMosaic</span><span class="p">,</span> <span class="n">DetectionMixup</span><span class="p">,</span> <span class="n">DetectionRandomAffine</span><span class="p">,</span>\
  112. <span class="n">DetectionTargetsFormatTransform</span><span class="p">,</span> <span class="n">DetectionPaddedRescale</span><span class="p">,</span> <span class="n">DetectionHSV</span><span class="p">,</span> <span class="n">DetectionHorizontalFlip</span>
  113. <span class="kn">from</span> <span class="nn">super_gradients.training.exceptions.dataset_exceptions</span> <span class="kn">import</span> <span class="n">IllegalDatasetParameterException</span>
  114. <span class="n">default_dataset_params</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;batch_size&quot;</span><span class="p">:</span> <span class="mi">64</span><span class="p">,</span> <span class="s2">&quot;val_batch_size&quot;</span><span class="p">:</span> <span class="mi">200</span><span class="p">,</span> <span class="s2">&quot;test_batch_size&quot;</span><span class="p">:</span> <span class="mi">200</span><span class="p">,</span> <span class="s2">&quot;dataset_dir&quot;</span><span class="p">:</span> <span class="s2">&quot;./data/&quot;</span><span class="p">,</span>
  115. <span class="s2">&quot;s3_link&quot;</span><span class="p">:</span> <span class="kc">None</span><span class="p">}</span>
  116. <span class="n">LIBRARY_DATASETS</span> <span class="o">=</span> <span class="p">{</span>
  117. <span class="s2">&quot;cifar10&quot;</span><span class="p">:</span> <span class="p">{</span><span class="s1">&#39;class&#39;</span><span class="p">:</span> <span class="n">datasets</span><span class="o">.</span><span class="n">CIFAR10</span><span class="p">,</span> <span class="s1">&#39;mean&#39;</span><span class="p">:</span> <span class="p">(</span><span class="mf">0.4914</span><span class="p">,</span> <span class="mf">0.4822</span><span class="p">,</span> <span class="mf">0.4465</span><span class="p">),</span> <span class="s1">&#39;std&#39;</span><span class="p">:</span> <span class="p">(</span><span class="mf">0.2023</span><span class="p">,</span> <span class="mf">0.1994</span><span class="p">,</span> <span class="mf">0.2010</span><span class="p">)},</span>
  118. <span class="s2">&quot;cifar100&quot;</span><span class="p">:</span> <span class="p">{</span><span class="s1">&#39;class&#39;</span><span class="p">:</span> <span class="n">datasets</span><span class="o">.</span><span class="n">CIFAR100</span><span class="p">,</span> <span class="s1">&#39;mean&#39;</span><span class="p">:</span> <span class="p">(</span><span class="mf">0.5071</span><span class="p">,</span> <span class="mf">0.4865</span><span class="p">,</span> <span class="mf">0.4409</span><span class="p">),</span> <span class="s1">&#39;std&#39;</span><span class="p">:</span> <span class="p">(</span><span class="mf">0.2673</span><span class="p">,</span> <span class="mf">0.2564</span><span class="p">,</span> <span class="mf">0.2762</span><span class="p">)},</span>
  119. <span class="s2">&quot;SVHN&quot;</span><span class="p">:</span> <span class="p">{</span><span class="s1">&#39;class&#39;</span><span class="p">:</span> <span class="n">datasets</span><span class="o">.</span><span class="n">SVHN</span><span class="p">,</span> <span class="s1">&#39;mean&#39;</span><span class="p">:</span> <span class="kc">None</span><span class="p">,</span> <span class="s1">&#39;std&#39;</span><span class="p">:</span> <span class="kc">None</span><span class="p">}</span>
  120. <span class="p">}</span>
  121. <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
  122. <div class="viewcode-block" id="DatasetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.DatasetInterface">[docs]</a><span class="k">class</span> <span class="nc">DatasetInterface</span><span class="p">:</span>
  123. <span class="sd">&quot;&quot;&quot;</span>
  124. <span class="sd"> DatasetInterface - This class manages all of the &quot;communiation&quot; the Model has with the Data Sets</span>
  125. <span class="sd"> &quot;&quot;&quot;</span>
  126. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="p">{},</span> <span class="n">train_loader</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">val_loader</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">test_loader</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">classes</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  127. <span class="sd">&quot;&quot;&quot;</span>
  128. <span class="sd"> @param train_loader: torch.utils.data.Dataloader (optional) dataloader for training.</span>
  129. <span class="sd"> @param test_loader: torch.utils.data.Dataloader (optional) dataloader for testing.</span>
  130. <span class="sd"> @param classes: list of classes.</span>
  131. <span class="sd"> Note: the above parameters will be discarded in case dataset_params is passed.</span>
  132. <span class="sd"> @param dataset_params:</span>
  133. <span class="sd"> - `batch_size` : int (default=64)</span>
  134. <span class="sd"> Number of examples per batch for training. Large batch sizes are recommended.</span>
  135. <span class="sd"> - `val_batch_size` : int (default=200)</span>
  136. <span class="sd"> Number of examples per batch for validation. Large batch sizes are recommended.</span>
  137. <span class="sd"> - `dataset_dir` : str (default=&quot;./data/&quot;)</span>
  138. <span class="sd"> Directory location for the data. Data will be downloaded to this directory when getting it from a</span>
  139. <span class="sd"> remote url.</span>
  140. <span class="sd"> - `s3_link` : str (default=None)</span>
  141. <span class="sd"> remote s3 link to download the data (optional).</span>
  142. <span class="sd"> - `aug_repeat_count` : int (default=0)</span>
  143. <span class="sd"> amount of repetitions (each repetition of an example is augmented differently) for each</span>
  144. <span class="sd"> example for the trainset.</span>
  145. <span class="sd"> &quot;&quot;&quot;</span>
  146. <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">HpmStruct</span><span class="p">(</span><span class="o">**</span><span class="n">default_dataset_params</span><span class="p">)</span>
  147. <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">override</span><span class="p">(</span><span class="o">**</span><span class="n">dataset_params</span><span class="p">)</span>
  148. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">valset</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">testset</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span>
  149. <span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">val_loader</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">test_loader</span> <span class="o">=</span> <span class="n">train_loader</span><span class="p">,</span> <span class="n">val_loader</span><span class="p">,</span> <span class="n">test_loader</span>
  150. <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="n">classes</span>
  151. <span class="bp">self</span><span class="o">.</span><span class="n">batch_size_factor</span> <span class="o">=</span> <span class="mi">1</span>
  152. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">s3_link</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  153. <span class="bp">self</span><span class="o">.</span><span class="n">download_from_cloud</span><span class="p">()</span>
  154. <div class="viewcode-block" id="DatasetInterface.download_from_cloud"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.DatasetInterface.download_from_cloud">[docs]</a> <span class="k">def</span> <span class="nf">download_from_cloud</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  155. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">s3_link</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  156. <span class="n">env_name</span> <span class="o">=</span> <span class="n">AWS_ENV_NAME</span>
  157. <span class="n">downloader</span> <span class="o">=</span> <span class="n">DatasetDataInterface</span><span class="p">(</span><span class="n">env</span><span class="o">=</span><span class="n">env_name</span><span class="p">)</span>
  158. <span class="n">target_dir</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">dataset_dir</span>
  159. <span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">target_dir</span><span class="p">):</span>
  160. <span class="n">os</span><span class="o">.</span><span class="n">mkdir</span><span class="p">(</span><span class="n">target_dir</span><span class="p">)</span>
  161. <span class="n">downloader</span><span class="o">.</span><span class="n">load_remote_dataset_file</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">s3_link</span><span class="p">,</span> <span class="n">target_dir</span><span class="p">)</span></div>
  162. <div class="viewcode-block" id="DatasetInterface.build_data_loaders"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.DatasetInterface.build_data_loaders">[docs]</a> <span class="k">def</span> <span class="nf">build_data_loaders</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size_factor</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">num_workers</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">train_batch_size</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">val_batch_size</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
  163. <span class="n">test_batch_size</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">distributed_sampler</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
  164. <span class="sd">&quot;&quot;&quot;</span>
  165. <span class="sd"> define train, val (and optionally test) loaders. The method deals separately with distributed training and standard</span>
  166. <span class="sd"> (non distributed, or parallel training). In the case of distributed training we need to rely on distributed</span>
  167. <span class="sd"> samplers.</span>
  168. <span class="sd"> :param batch_size_factor: int - factor to multiply the batch size (usually for multi gpu)</span>
  169. <span class="sd"> :param num_workers: int - number of workers (parallel processes) for dataloaders</span>
  170. <span class="sd"> :param train_batch_size: int - batch size for train loader, if None will be taken from dataset_params</span>
  171. <span class="sd"> :param val_batch_size: int - batch size for val loader, if None will be taken from dataset_params</span>
  172. <span class="sd"> :param distributed_sampler: boolean flag for distributed training mode</span>
  173. <span class="sd"> :return: train_loader, val_loader, classes: list of classes</span>
  174. <span class="sd"> &quot;&quot;&quot;</span>
  175. <span class="c1"># CHANGE THE BATCH SIZE ACCORDING TO THE NUMBER OF DEVICES - ONLY IN NON-DISTRIBUTED TRAINING MODE</span>
  176. <span class="c1"># IN DISTRIBUTED MODE WE NEED DISTRIBUTED SAMPLERS</span>
  177. <span class="c1"># NO SHUFFLE IN DISTRIBUTED TRAINING</span>
  178. <span class="n">aug_repeat_count</span> <span class="o">=</span> <span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s2">&quot;aug_repeat_count&quot;</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
  179. <span class="k">if</span> <span class="n">aug_repeat_count</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">distributed_sampler</span><span class="p">:</span>
  180. <span class="k">raise</span> <span class="n">IllegalDatasetParameterException</span><span class="p">(</span><span class="s2">&quot;repeated augmentation is only supported with DDP.&quot;</span><span class="p">)</span>
  181. <span class="k">if</span> <span class="n">distributed_sampler</span><span class="p">:</span>
  182. <span class="bp">self</span><span class="o">.</span><span class="n">batch_size_factor</span> <span class="o">=</span> <span class="mi">1</span>
  183. <span class="n">train_sampler</span> <span class="o">=</span> <span class="n">RepeatAugSampler</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="p">,</span>
  184. <span class="n">num_repeats</span><span class="o">=</span><span class="n">aug_repeat_count</span><span class="p">)</span> <span class="k">if</span> <span class="n">aug_repeat_count</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="k">else</span> <span class="n">DistributedSampler</span><span class="p">(</span>
  185. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="p">)</span>
  186. <span class="n">val_sampler</span> <span class="o">=</span> <span class="n">DistributedSampler</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">valset</span><span class="p">)</span>
  187. <span class="n">test_sampler</span> <span class="o">=</span> <span class="n">DistributedSampler</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">testset</span><span class="p">)</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">testset</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="kc">None</span>
  188. <span class="n">train_shuffle</span> <span class="o">=</span> <span class="kc">False</span>
  189. <span class="k">else</span><span class="p">:</span>
  190. <span class="bp">self</span><span class="o">.</span><span class="n">batch_size_factor</span> <span class="o">=</span> <span class="n">batch_size_factor</span>
  191. <span class="n">train_sampler</span> <span class="o">=</span> <span class="kc">None</span>
  192. <span class="n">val_sampler</span> <span class="o">=</span> <span class="kc">None</span>
  193. <span class="n">test_sampler</span> <span class="o">=</span> <span class="kc">None</span>
  194. <span class="n">train_shuffle</span> <span class="o">=</span> <span class="kc">True</span>
  195. <span class="k">if</span> <span class="n">train_batch_size</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
  196. <span class="n">train_batch_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size_factor</span>
  197. <span class="k">if</span> <span class="n">val_batch_size</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
  198. <span class="n">val_batch_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">val_batch_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size_factor</span>
  199. <span class="k">if</span> <span class="n">test_batch_size</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
  200. <span class="n">test_batch_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">test_batch_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size_factor</span>
  201. <span class="n">train_loader_drop_last</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;train_loader_drop_last&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
  202. <span class="n">cutmix</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;cutmix&#39;</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span>
  203. <span class="n">cutmix_params</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;cutmix_params&#39;</span><span class="p">)</span>
  204. <span class="c1"># WRAPPING collate_fn</span>
  205. <span class="n">train_collate_fn</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="p">,</span> <span class="s1">&#39;collate_fn&#39;</span><span class="p">)</span>
  206. <span class="n">val_collate_fn</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">valset</span><span class="p">,</span> <span class="s1">&#39;collate_fn&#39;</span><span class="p">)</span>
  207. <span class="n">test_collate_fn</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">testset</span><span class="p">,</span> <span class="s1">&#39;collate_fn&#39;</span><span class="p">)</span>
  208. <span class="k">if</span> <span class="n">cutmix</span> <span class="ow">and</span> <span class="n">train_collate_fn</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  209. <span class="k">raise</span> <span class="n">IllegalDatasetParameterException</span><span class="p">(</span><span class="s2">&quot;cutmix and collate function cannot be used together&quot;</span><span class="p">)</span>
  210. <span class="k">if</span> <span class="n">cutmix</span><span class="p">:</span>
  211. <span class="c1"># FIXME - cutmix should be available only in classification dataset. once we make sure all classification</span>
  212. <span class="c1"># datasets inherit from the same super class, we should move cutmix code to that class</span>
  213. <span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s2">&quot;Cutmix/mixup was enabled. This feature is currently supported only &quot;</span>
  214. <span class="s2">&quot;for classification datasets.&quot;</span><span class="p">)</span>
  215. <span class="n">train_collate_fn</span> <span class="o">=</span> <span class="n">CollateMixup</span><span class="p">(</span><span class="o">**</span><span class="n">cutmix_params</span><span class="p">)</span>
  216. <span class="c1"># FIXME - UNDERSTAND IF THE num_replicas VARIBALE IS NEEDED</span>
  217. <span class="c1"># train_sampler = DistributedSampler(self.trainset,</span>
  218. <span class="c1"># num_replicas=distributed_gpus_num) if distributed_sampler else None</span>
  219. <span class="c1"># val_sampler = DistributedSampler(self.valset,</span>
  220. <span class="c1"># num_replicas=distributed_gpus_num) if distributed_sampler else None</span>
  221. <span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="p">,</span>
  222. <span class="n">batch_size</span><span class="o">=</span><span class="n">train_batch_size</span><span class="p">,</span>
  223. <span class="n">shuffle</span><span class="o">=</span><span class="n">train_shuffle</span><span class="p">,</span>
  224. <span class="n">num_workers</span><span class="o">=</span><span class="n">num_workers</span><span class="p">,</span>
  225. <span class="n">pin_memory</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
  226. <span class="n">sampler</span><span class="o">=</span><span class="n">train_sampler</span><span class="p">,</span>
  227. <span class="n">collate_fn</span><span class="o">=</span><span class="n">train_collate_fn</span><span class="p">,</span>
  228. <span class="n">drop_last</span><span class="o">=</span><span class="n">train_loader_drop_last</span><span class="p">)</span>
  229. <span class="bp">self</span><span class="o">.</span><span class="n">val_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">valset</span><span class="p">,</span>
  230. <span class="n">batch_size</span><span class="o">=</span><span class="n">val_batch_size</span><span class="p">,</span>
  231. <span class="n">shuffle</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
  232. <span class="n">num_workers</span><span class="o">=</span><span class="n">num_workers</span><span class="p">,</span>
  233. <span class="n">pin_memory</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
  234. <span class="n">sampler</span><span class="o">=</span><span class="n">val_sampler</span><span class="p">,</span>
  235. <span class="n">collate_fn</span><span class="o">=</span><span class="n">val_collate_fn</span><span class="p">)</span>
  236. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">testset</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  237. <span class="bp">self</span><span class="o">.</span><span class="n">test_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">testset</span><span class="p">,</span>
  238. <span class="n">batch_size</span><span class="o">=</span><span class="n">test_batch_size</span><span class="p">,</span>
  239. <span class="n">shuffle</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
  240. <span class="n">num_workers</span><span class="o">=</span><span class="n">num_workers</span><span class="p">,</span>
  241. <span class="n">pin_memory</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
  242. <span class="n">sampler</span><span class="o">=</span><span class="n">test_sampler</span><span class="p">,</span>
  243. <span class="n">collate_fn</span><span class="o">=</span><span class="n">test_collate_fn</span><span class="p">)</span>
  244. <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="o">.</span><span class="n">classes</span></div>
  245. <div class="viewcode-block" id="DatasetInterface.get_data_loaders"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.DatasetInterface.get_data_loaders">[docs]</a> <span class="k">def</span> <span class="nf">get_data_loaders</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  246. <span class="sd">&quot;&quot;&quot;</span>
  247. <span class="sd"> Get self.train_loader, self.val_loader, self.test_loader, self.classes.</span>
  248. <span class="sd"> If the data loaders haven&#39;t been initialized yet, build them first.</span>
  249. <span class="sd"> :param kwargs: kwargs are passed to build_data_loaders.</span>
  250. <span class="sd"> &quot;&quot;&quot;</span>
  251. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">val_loader</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
  252. <span class="bp">self</span><span class="o">.</span><span class="n">build_data_loaders</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
  253. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">val_loader</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">test_loader</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">classes</span></div>
  254. <div class="viewcode-block" id="DatasetInterface.get_val_sample"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.DatasetInterface.get_val_sample">[docs]</a> <span class="k">def</span> <span class="nf">get_val_sample</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_samples</span><span class="o">=</span><span class="mi">1</span><span class="p">):</span>
  255. <span class="k">if</span> <span class="n">num_samples</span> <span class="o">&gt;</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">valset</span><span class="p">):</span>
  256. <span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="s2">&quot;Tried to load more samples than val-set size&quot;</span><span class="p">)</span>
  257. <span class="k">if</span> <span class="n">num_samples</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
  258. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">valset</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
  259. <span class="k">else</span><span class="p">:</span>
  260. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">valset</span><span class="p">[</span><span class="mi">0</span><span class="p">:</span><span class="n">num_samples</span><span class="p">]</span></div>
  261. <div class="viewcode-block" id="DatasetInterface.get_dataset_params"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.DatasetInterface.get_dataset_params">[docs]</a> <span class="k">def</span> <span class="nf">get_dataset_params</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  262. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span></div>
  263. <div class="viewcode-block" id="DatasetInterface.print_dataset_details"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.DatasetInterface.print_dataset_details">[docs]</a> <span class="k">def</span> <span class="nf">print_dataset_details</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  264. <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">&quot;</span><span class="si">{}</span><span class="s2"> training samples, </span><span class="si">{}</span><span class="s2"> val samples, </span><span class="si">{}</span><span class="s2"> classes&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="p">),</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">valset</span><span class="p">),</span>
  265. <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="o">.</span><span class="n">classes</span><span class="p">)))</span></div></div>
  266. <div class="viewcode-block" id="ExternalDatasetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.ExternalDatasetInterface">[docs]</a><span class="k">class</span> <span class="nc">ExternalDatasetInterface</span><span class="p">(</span><span class="n">DatasetInterface</span><span class="p">):</span>
  267. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">train_loader</span><span class="p">,</span> <span class="n">val_loader</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="p">{}):</span>
  268. <span class="sd">&quot;&quot;&quot;</span>
  269. <span class="sd"> ExternalDatasetInterface - A wrapper for external dataset interface that gets dataloaders from keras/TF</span>
  270. <span class="sd"> and converts them to Torch-like dataloaders that return torch.Tensors after</span>
  271. <span class="sd"> optional collate_fn while maintaining the same interface (connect_dataset_interface etc.)</span>
  272. <span class="sd"> :train_loader: The external train_loader</span>
  273. <span class="sd"> :val_loader: The external val_loader</span>
  274. <span class="sd"> :num_classes: The number of classes</span>
  275. <span class="sd"> :dataset_params The dict that includes the batch_size and/or the collate_fn</span>
  276. <span class="sd"> :return: DataLoaders that generate torch.Tensors batches after collate_fn</span>
  277. <span class="sd"> &quot;&quot;&quot;</span>
  278. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">)</span>
  279. <span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span> <span class="o">=</span> <span class="n">train_loader</span>
  280. <span class="bp">self</span><span class="o">.</span><span class="n">val_loader</span> <span class="o">=</span> <span class="n">val_loader</span>
  281. <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="n">num_classes</span>
  282. <div class="viewcode-block" id="ExternalDatasetInterface.get_data_loaders"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.ExternalDatasetInterface.get_data_loaders">[docs]</a> <span class="k">def</span> <span class="nf">get_data_loaders</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size_factor</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span> <span class="n">num_workers</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">8</span><span class="p">,</span> <span class="n">train_batch_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
  283. <span class="n">val_batch_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">distributed_sampler</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
  284. <span class="c1"># CHANGE THE BATCH SIZE ACCORDING TO THE NUMBER OF DEVICES - ONLY IN NON-DISTRIBUED TRAINING MODE</span>
  285. <span class="c1"># IN DISTRIBUTED MODE WE NEED DISTRIBUTED SAMPLERS</span>
  286. <span class="c1"># NO SHUFFLE IN DISTRIBUTED TRAINING</span>
  287. <span class="k">if</span> <span class="n">distributed_sampler</span><span class="p">:</span>
  288. <span class="bp">self</span><span class="o">.</span><span class="n">batch_size_factor</span> <span class="o">=</span> <span class="mi">1</span>
  289. <span class="n">train_sampler</span> <span class="o">=</span> <span class="n">DistributedSampler</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  290. <span class="n">val_sampler</span> <span class="o">=</span> <span class="n">DistributedSampler</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">valset</span><span class="p">)</span>
  291. <span class="n">train_shuffle</span> <span class="o">=</span> <span class="kc">False</span>
  292. <span class="k">else</span><span class="p">:</span>
  293. <span class="bp">self</span><span class="o">.</span><span class="n">batch_size_factor</span> <span class="o">=</span> <span class="n">batch_size_factor</span>
  294. <span class="n">train_sampler</span> <span class="o">=</span> <span class="kc">None</span>
  295. <span class="n">val_sampler</span> <span class="o">=</span> <span class="kc">None</span>
  296. <span class="n">train_shuffle</span> <span class="o">=</span> <span class="kc">True</span>
  297. <span class="k">if</span> <span class="n">train_batch_size</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
  298. <span class="n">train_batch_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size_factor</span>
  299. <span class="k">if</span> <span class="n">val_batch_size</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
  300. <span class="n">val_batch_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">val_batch_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size_factor</span>
  301. <span class="n">train_loader_drop_last</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;train_loader_drop_last&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
  302. <span class="c1"># WRAPPING collate_fn</span>
  303. <span class="n">train_collate_fn</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;train_collate_fn&#39;</span><span class="p">)</span>
  304. <span class="n">val_collate_fn</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;val_collate_fn&#39;</span><span class="p">)</span>
  305. <span class="c1"># FIXME - UNDERSTAND IF THE num_replicas VARIBALE IS NEEDED</span>
  306. <span class="c1"># train_sampler = DistributedSampler(self.trainset,</span>
  307. <span class="c1"># num_replicas=distributed_gpus_num) if distributed_sampler else None</span>
  308. <span class="c1"># val_sampler = DistributedSampler(self.valset,</span>
  309. <span class="c1"># num_replicas=distributed_gpus_num) if distributed_sampler else None</span>
  310. <span class="bp">self</span><span class="o">.</span><span class="n">torch_train_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span><span class="p">,</span>
  311. <span class="n">batch_size</span><span class="o">=</span><span class="n">train_batch_size</span><span class="p">,</span>
  312. <span class="n">shuffle</span><span class="o">=</span><span class="n">train_shuffle</span><span class="p">,</span>
  313. <span class="n">num_workers</span><span class="o">=</span><span class="n">num_workers</span><span class="p">,</span>
  314. <span class="n">pin_memory</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
  315. <span class="n">sampler</span><span class="o">=</span><span class="n">train_sampler</span><span class="p">,</span>
  316. <span class="n">collate_fn</span><span class="o">=</span><span class="n">train_collate_fn</span><span class="p">,</span>
  317. <span class="n">drop_last</span><span class="o">=</span><span class="n">train_loader_drop_last</span><span class="p">)</span>
  318. <span class="bp">self</span><span class="o">.</span><span class="n">torch_val_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">val_loader</span><span class="p">,</span>
  319. <span class="n">batch_size</span><span class="o">=</span><span class="n">val_batch_size</span><span class="p">,</span>
  320. <span class="n">shuffle</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
  321. <span class="n">num_workers</span><span class="o">=</span><span class="n">num_workers</span><span class="p">,</span>
  322. <span class="n">pin_memory</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
  323. <span class="n">sampler</span><span class="o">=</span><span class="n">val_sampler</span><span class="p">,</span>
  324. <span class="n">collate_fn</span><span class="o">=</span><span class="n">val_collate_fn</span><span class="p">)</span>
  325. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">torch_train_loader</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">torch_val_loader</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">classes</span></div></div>
  326. <div class="viewcode-block" id="LibraryDatasetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.LibraryDatasetInterface">[docs]</a><span class="k">class</span> <span class="nc">LibraryDatasetInterface</span><span class="p">(</span><span class="n">DatasetInterface</span><span class="p">):</span>
  327. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;cifar10&quot;</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="p">{},</span> <span class="n">to_cutout</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
  328. <span class="nb">super</span><span class="p">(</span><span class="n">LibraryDatasetInterface</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">)</span>
  329. <span class="bp">self</span><span class="o">.</span><span class="n">dataset_name</span> <span class="o">=</span> <span class="n">name</span>
  330. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_name</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">LIBRARY_DATASETS</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
  331. <span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="s1">&#39;dataset not found&#39;</span><span class="p">)</span>
  332. <span class="bp">self</span><span class="o">.</span><span class="n">lib_dataset_params</span> <span class="o">=</span> <span class="n">LIBRARY_DATASETS</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_name</span><span class="p">]</span>
  333. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">lib_dataset_params</span><span class="p">[</span><span class="s1">&#39;mean&#39;</span><span class="p">]</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
  334. <span class="n">trainset</span> <span class="o">=</span> <span class="n">torchvision</span><span class="o">.</span><span class="n">datasets</span><span class="o">.</span><span class="n">SVHN</span><span class="p">(</span><span class="n">root</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">dataset_dir</span><span class="p">,</span> <span class="n">split</span><span class="o">=</span><span class="s1">&#39;train&#39;</span><span class="p">,</span> <span class="n">download</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
  335. <span class="n">transform</span><span class="o">=</span><span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">())</span>
  336. <span class="bp">self</span><span class="o">.</span><span class="n">lib_dataset_params</span><span class="p">[</span><span class="s1">&#39;mean&#39;</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">lib_dataset_params</span><span class="p">[</span><span class="s1">&#39;std&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">datasets_utils</span><span class="o">.</span><span class="n">get_mean_and_std</span><span class="p">(</span><span class="n">trainset</span><span class="p">)</span>
  337. <span class="c1"># OVERWRITE MEAN AND STD IF DEFINED IN DATASET PARAMS</span>
  338. <span class="bp">self</span><span class="o">.</span><span class="n">lib_dataset_params</span><span class="p">[</span><span class="s1">&#39;mean&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;img_mean&#39;</span><span class="p">,</span>
  339. <span class="n">default_val</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">lib_dataset_params</span><span class="p">[</span><span class="s1">&#39;mean&#39;</span><span class="p">])</span>
  340. <span class="bp">self</span><span class="o">.</span><span class="n">lib_dataset_params</span><span class="p">[</span><span class="s1">&#39;std&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;img_std&#39;</span><span class="p">,</span>
  341. <span class="n">default_val</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">lib_dataset_params</span><span class="p">[</span><span class="s1">&#39;std&#39;</span><span class="p">])</span>
  342. <span class="n">crop_size</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;crop_size&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="mi">32</span><span class="p">)</span>
  343. <span class="k">if</span> <span class="n">to_cutout</span><span class="p">:</span>
  344. <span class="n">transform_train</span> <span class="o">=</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span>
  345. <span class="n">transforms</span><span class="o">.</span><span class="n">RandomCrop</span><span class="p">(</span><span class="n">crop_size</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
  346. <span class="n">transforms</span><span class="o">.</span><span class="n">RandomHorizontalFlip</span><span class="p">(),</span>
  347. <span class="n">DataAugmentation</span><span class="o">.</span><span class="n">normalize</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">lib_dataset_params</span><span class="p">[</span><span class="s1">&#39;mean&#39;</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">lib_dataset_params</span><span class="p">[</span><span class="s1">&#39;std&#39;</span><span class="p">]),</span>
  348. <span class="n">DataAugmentation</span><span class="o">.</span><span class="n">cutout</span><span class="p">(</span><span class="mi">16</span><span class="p">),</span>
  349. <span class="n">DataAugmentation</span><span class="o">.</span><span class="n">to_tensor</span><span class="p">()</span>
  350. <span class="p">])</span>
  351. <span class="k">else</span><span class="p">:</span>
  352. <span class="n">transform_train</span> <span class="o">=</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span>
  353. <span class="n">transforms</span><span class="o">.</span><span class="n">RandomCrop</span><span class="p">(</span><span class="n">crop_size</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span>
  354. <span class="n">transforms</span><span class="o">.</span><span class="n">RandomHorizontalFlip</span><span class="p">(),</span>
  355. <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span>
  356. <span class="n">transforms</span><span class="o">.</span><span class="n">Normalize</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">lib_dataset_params</span><span class="p">[</span><span class="s1">&#39;mean&#39;</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">lib_dataset_params</span><span class="p">[</span><span class="s1">&#39;std&#39;</span><span class="p">]),</span>
  357. <span class="p">])</span>
  358. <span class="n">transform_val</span> <span class="o">=</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span>
  359. <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span>
  360. <span class="n">transforms</span><span class="o">.</span><span class="n">Normalize</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">lib_dataset_params</span><span class="p">[</span><span class="s1">&#39;mean&#39;</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">lib_dataset_params</span><span class="p">[</span><span class="s1">&#39;std&#39;</span><span class="p">]),</span>
  361. <span class="p">])</span>
  362. <span class="n">dataset_cls</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">lib_dataset_params</span><span class="p">[</span><span class="s2">&quot;class&quot;</span><span class="p">]</span>
  363. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span> <span class="o">=</span> <span class="n">dataset_cls</span><span class="p">(</span><span class="n">root</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">dataset_dir</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">download</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
  364. <span class="n">transform</span><span class="o">=</span><span class="n">transform_train</span><span class="p">)</span>
  365. <span class="bp">self</span><span class="o">.</span><span class="n">valset</span> <span class="o">=</span> <span class="n">dataset_cls</span><span class="p">(</span><span class="n">root</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">dataset_dir</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">download</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
  366. <span class="n">transform</span><span class="o">=</span><span class="n">transform_val</span><span class="p">)</span></div>
  367. <div class="viewcode-block" id="Cifar10DatasetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.Cifar10DatasetInterface">[docs]</a><span class="k">class</span> <span class="nc">Cifar10DatasetInterface</span><span class="p">(</span><span class="n">LibraryDatasetInterface</span><span class="p">):</span>
  368. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="p">{}):</span>
  369. <span class="nb">super</span><span class="p">(</span><span class="n">Cifar10DatasetInterface</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">&quot;cifar10&quot;</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">)</span></div>
  370. <div class="viewcode-block" id="Cifar100DatasetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.Cifar100DatasetInterface">[docs]</a><span class="k">class</span> <span class="nc">Cifar100DatasetInterface</span><span class="p">(</span><span class="n">LibraryDatasetInterface</span><span class="p">):</span>
  371. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="p">{}):</span>
  372. <span class="nb">super</span><span class="p">(</span><span class="n">Cifar100DatasetInterface</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">&quot;cifar100&quot;</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">)</span></div>
  373. <div class="viewcode-block" id="TestDatasetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.TestDatasetInterface">[docs]</a><span class="k">class</span> <span class="nc">TestDatasetInterface</span><span class="p">(</span><span class="n">DatasetInterface</span><span class="p">):</span>
  374. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">trainset</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="p">{},</span> <span class="n">classes</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  375. <span class="nb">super</span><span class="p">(</span><span class="n">TestDatasetInterface</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">)</span>
  376. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span> <span class="o">=</span> <span class="n">trainset</span>
  377. <span class="bp">self</span><span class="o">.</span><span class="n">valset</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span>
  378. <span class="bp">self</span><span class="o">.</span><span class="n">testset</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span>
  379. <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="n">classes</span>
  380. <div class="viewcode-block" id="TestDatasetInterface.get_data_loaders"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.TestDatasetInterface.get_data_loaders">[docs]</a> <span class="k">def</span> <span class="nf">get_data_loaders</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size_factor</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">num_workers</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">train_batch_size</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">val_batch_size</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
  381. <span class="n">distributed_sampler</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
  382. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">]</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">classes</span>
  383. <span class="k">return</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">get_data_loaders</span><span class="p">(</span><span class="n">batch_size_factor</span><span class="o">=</span><span class="n">batch_size_factor</span><span class="p">,</span>
  384. <span class="n">num_workers</span><span class="o">=</span><span class="n">num_workers</span><span class="p">,</span>
  385. <span class="n">train_batch_size</span><span class="o">=</span><span class="n">train_batch_size</span><span class="p">,</span>
  386. <span class="n">val_batch_size</span><span class="o">=</span><span class="n">val_batch_size</span><span class="p">,</span>
  387. <span class="n">distributed_sampler</span><span class="o">=</span><span class="n">distributed_sampler</span><span class="p">)</span></div></div>
  388. <div class="viewcode-block" id="ClassificationTestDatasetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.ClassificationTestDatasetInterface">[docs]</a><span class="k">class</span> <span class="nc">ClassificationTestDatasetInterface</span><span class="p">(</span><span class="n">TestDatasetInterface</span><span class="p">):</span>
  389. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="p">{},</span> <span class="n">image_size</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">classes</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  390. <span class="n">trainset</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">TensorDataset</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="n">image_size</span><span class="p">))),</span>
  391. <span class="n">torch</span><span class="o">.</span><span class="n">LongTensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">))))</span>
  392. <span class="nb">super</span><span class="p">(</span><span class="n">ClassificationTestDatasetInterface</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">trainset</span><span class="o">=</span><span class="n">trainset</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
  393. <span class="n">classes</span><span class="o">=</span><span class="n">classes</span><span class="p">)</span></div>
  394. <div class="viewcode-block" id="SegmentationTestDatasetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.SegmentationTestDatasetInterface">[docs]</a><span class="k">class</span> <span class="nc">SegmentationTestDatasetInterface</span><span class="p">(</span><span class="n">TestDatasetInterface</span><span class="p">):</span>
  395. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="p">{},</span> <span class="n">image_size</span><span class="o">=</span><span class="mi">512</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">4</span><span class="p">):</span>
  396. <span class="n">trainset</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">TensorDataset</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="n">image_size</span><span class="p">))),</span>
  397. <span class="n">torch</span><span class="o">.</span><span class="n">LongTensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="n">image_size</span><span class="p">))))</span>
  398. <span class="nb">super</span><span class="p">(</span><span class="n">SegmentationTestDatasetInterface</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">trainset</span><span class="o">=</span><span class="n">trainset</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">)</span></div>
  399. <div class="viewcode-block" id="DetectionTestDatasetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.DetectionTestDatasetInterface">[docs]</a><span class="k">class</span> <span class="nc">DetectionTestDatasetInterface</span><span class="p">(</span><span class="n">TestDatasetInterface</span><span class="p">):</span>
  400. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="p">{},</span> <span class="n">image_size</span><span class="o">=</span><span class="mi">320</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">classes</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  401. <span class="n">trainset</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">TensorDataset</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="n">image_size</span><span class="p">))),</span>
  402. <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">6</span><span class="p">))))</span>
  403. <span class="nb">super</span><span class="p">(</span><span class="n">DetectionTestDatasetInterface</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">trainset</span><span class="o">=</span><span class="n">trainset</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
  404. <span class="n">classes</span><span class="o">=</span><span class="n">classes</span><span class="p">)</span></div>
  405. <div class="viewcode-block" id="TestYoloDetectionDatasetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.TestYoloDetectionDatasetInterface">[docs]</a><span class="k">class</span> <span class="nc">TestYoloDetectionDatasetInterface</span><span class="p">(</span><span class="n">DatasetInterface</span><span class="p">):</span>
  406. <span class="sd">&quot;&quot;&quot;</span>
  407. <span class="sd"> note: the output size is (batch_size, 6) in the test while in real training</span>
  408. <span class="sd"> the size of axis 0 can vary (the number of bounding boxes)</span>
  409. <span class="sd"> &quot;&quot;&quot;</span>
  410. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="p">{},</span> <span class="n">input_dims</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">),</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">5</span><span class="p">):</span>
  411. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">)</span>
  412. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">TensorDataset</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">*</span><span class="n">input_dims</span><span class="p">)),</span>
  413. <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">6</span><span class="p">)))</span>
  414. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">]</span>
  415. <span class="bp">self</span><span class="o">.</span><span class="n">valset</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span></div>
  416. <div class="viewcode-block" id="ImageNetDatasetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.ImageNetDatasetInterface">[docs]</a><span class="k">class</span> <span class="nc">ImageNetDatasetInterface</span><span class="p">(</span><span class="n">DatasetInterface</span><span class="p">):</span>
  417. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="p">{},</span> <span class="n">data_dir</span><span class="o">=</span><span class="s2">&quot;/data/Imagenet&quot;</span><span class="p">):</span>
  418. <span class="nb">super</span><span class="p">(</span><span class="n">ImageNetDatasetInterface</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">)</span>
  419. <span class="n">data_dir</span> <span class="o">=</span> <span class="n">dataset_params</span><span class="p">[</span><span class="s1">&#39;dataset_dir&#39;</span><span class="p">]</span> <span class="k">if</span> <span class="s1">&#39;dataset_dir&#39;</span> <span class="ow">in</span> <span class="n">dataset_params</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span> <span class="k">else</span> <span class="n">data_dir</span>
  420. <span class="n">traindir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">abspath</span><span class="p">(</span><span class="n">data_dir</span><span class="p">),</span> <span class="s1">&#39;train&#39;</span><span class="p">)</span>
  421. <span class="n">valdir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">data_dir</span><span class="p">,</span> <span class="s1">&#39;val&#39;</span><span class="p">)</span>
  422. <span class="n">img_mean</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;img_mean&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="p">[</span><span class="mf">0.485</span><span class="p">,</span> <span class="mf">0.456</span><span class="p">,</span> <span class="mf">0.406</span><span class="p">])</span>
  423. <span class="n">img_std</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;img_std&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="p">[</span><span class="mf">0.229</span><span class="p">,</span> <span class="mf">0.224</span><span class="p">,</span> <span class="mf">0.225</span><span class="p">])</span>
  424. <span class="n">normalize</span> <span class="o">=</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Normalize</span><span class="p">(</span><span class="n">mean</span><span class="o">=</span><span class="n">img_mean</span><span class="p">,</span>
  425. <span class="n">std</span><span class="o">=</span><span class="n">img_std</span><span class="p">)</span>
  426. <span class="n">crop_size</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;crop_size&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="mi">224</span><span class="p">)</span>
  427. <span class="n">resize_size</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;resize_size&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="mi">256</span><span class="p">)</span>
  428. <span class="n">color_jitter</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;color_jitter&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="mf">0.0</span><span class="p">)</span>
  429. <span class="n">imagenet_pca_aug</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;imagenet_pca_aug&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="mf">0.0</span><span class="p">)</span>
  430. <span class="n">train_interpolation</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;train_interpolation&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="s1">&#39;default&#39;</span><span class="p">)</span>
  431. <span class="n">rand_augment_config_string</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;rand_augment_config_string&#39;</span><span class="p">,</span>
  432. <span class="n">default_val</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
  433. <span class="n">color_jitter</span> <span class="o">=</span> <span class="p">(</span><span class="nb">float</span><span class="p">(</span><span class="n">color_jitter</span><span class="p">),)</span> <span class="o">*</span> <span class="mi">3</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">color_jitter</span><span class="p">,</span> <span class="nb">float</span><span class="p">)</span> <span class="k">else</span> <span class="n">color_jitter</span>
  434. <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">color_jitter</span><span class="p">)</span> <span class="ow">in</span> <span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">),</span> <span class="s2">&quot;color_jitter must be a scalar or tuple of len 3 or 4&quot;</span>
  435. <span class="n">color_augmentation</span> <span class="o">=</span> <span class="n">datasets_utils</span><span class="o">.</span><span class="n">get_color_augmentation</span><span class="p">(</span><span class="n">rand_augment_config_string</span><span class="p">,</span> <span class="n">color_jitter</span><span class="p">,</span>
  436. <span class="n">crop_size</span><span class="o">=</span><span class="n">crop_size</span><span class="p">,</span> <span class="n">img_mean</span><span class="o">=</span><span class="n">img_mean</span><span class="p">)</span>
  437. <span class="n">train_transformation_list</span> <span class="o">=</span> <span class="p">[</span>
  438. <span class="n">RandomResizedCropAndInterpolation</span><span class="p">(</span><span class="n">crop_size</span><span class="p">,</span> <span class="n">interpolation</span><span class="o">=</span><span class="n">train_interpolation</span><span class="p">),</span>
  439. <span class="n">transforms</span><span class="o">.</span><span class="n">RandomHorizontalFlip</span><span class="p">(),</span>
  440. <span class="n">color_augmentation</span><span class="p">,</span>
  441. <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span>
  442. <span class="n">Lighting</span><span class="p">(</span><span class="n">imagenet_pca_aug</span><span class="p">),</span>
  443. <span class="n">normalize</span><span class="p">]</span>
  444. <span class="n">rndm_erase_prob</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;random_erase_prob&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="mf">0.</span><span class="p">)</span>
  445. <span class="k">if</span> <span class="n">rndm_erase_prob</span><span class="p">:</span>
  446. <span class="n">train_transformation_list</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">RandomErase</span><span class="p">(</span><span class="n">rndm_erase_prob</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">random_erase_value</span><span class="p">))</span>
  447. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span> <span class="o">=</span> <span class="n">datasets</span><span class="o">.</span><span class="n">ImageFolder</span><span class="p">(</span><span class="n">traindir</span><span class="p">,</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">(</span><span class="n">train_transformation_list</span><span class="p">))</span>
  448. <span class="bp">self</span><span class="o">.</span><span class="n">valset</span> <span class="o">=</span> <span class="n">datasets</span><span class="o">.</span><span class="n">ImageFolder</span><span class="p">(</span><span class="n">valdir</span><span class="p">,</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span>
  449. <span class="n">transforms</span><span class="o">.</span><span class="n">Resize</span><span class="p">(</span><span class="n">resize_size</span><span class="p">),</span>
  450. <span class="n">transforms</span><span class="o">.</span><span class="n">CenterCrop</span><span class="p">(</span><span class="n">crop_size</span><span class="p">),</span>
  451. <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span>
  452. <span class="n">normalize</span><span class="p">,</span>
  453. <span class="p">]))</span></div>
  454. <div class="viewcode-block" id="TinyImageNetDatasetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.TinyImageNetDatasetInterface">[docs]</a><span class="k">class</span> <span class="nc">TinyImageNetDatasetInterface</span><span class="p">(</span><span class="n">DatasetInterface</span><span class="p">):</span>
  455. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="p">{},</span> <span class="n">data_dir</span><span class="o">=</span><span class="s2">&quot;/data/TinyImagenet&quot;</span><span class="p">):</span>
  456. <span class="nb">super</span><span class="p">(</span><span class="n">TinyImageNetDatasetInterface</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">)</span>
  457. <span class="n">data_dir</span> <span class="o">=</span> <span class="n">dataset_params</span><span class="p">[</span><span class="s1">&#39;dataset_dir&#39;</span><span class="p">]</span> <span class="k">if</span> <span class="s1">&#39;dataset_dir&#39;</span> <span class="ow">in</span> <span class="n">dataset_params</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span> <span class="k">else</span> <span class="n">data_dir</span>
  458. <span class="n">traindir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">abspath</span><span class="p">(</span><span class="n">data_dir</span><span class="p">),</span> <span class="s1">&#39;train&#39;</span><span class="p">)</span>
  459. <span class="n">valdir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">data_dir</span><span class="p">,</span> <span class="s1">&#39;val&#39;</span><span class="p">)</span>
  460. <span class="n">img_mean</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;img_mean&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="p">[</span><span class="mf">0.4802</span><span class="p">,</span> <span class="mf">0.4481</span><span class="p">,</span> <span class="mf">0.3975</span><span class="p">])</span>
  461. <span class="n">img_std</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;img_std&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="p">[</span><span class="mf">0.2770</span><span class="p">,</span> <span class="mf">0.2691</span><span class="p">,</span> <span class="mf">0.2821</span><span class="p">])</span>
  462. <span class="n">normalize</span> <span class="o">=</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Normalize</span><span class="p">(</span><span class="n">mean</span><span class="o">=</span><span class="n">img_mean</span><span class="p">,</span>
  463. <span class="n">std</span><span class="o">=</span><span class="n">img_std</span><span class="p">)</span>
  464. <span class="n">crop_size</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;crop_size&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="mi">56</span><span class="p">)</span>
  465. <span class="n">resize_size</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;resize_size&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="mi">64</span><span class="p">)</span>
  466. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span> <span class="o">=</span> <span class="n">datasets</span><span class="o">.</span><span class="n">ImageFolder</span><span class="p">(</span>
  467. <span class="n">traindir</span><span class="p">,</span>
  468. <span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span>
  469. <span class="n">transforms</span><span class="o">.</span><span class="n">RandomResizedCrop</span><span class="p">(</span><span class="n">crop_size</span><span class="p">),</span>
  470. <span class="n">transforms</span><span class="o">.</span><span class="n">RandomHorizontalFlip</span><span class="p">(),</span>
  471. <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span>
  472. <span class="n">normalize</span><span class="p">,</span>
  473. <span class="p">]))</span>
  474. <span class="bp">self</span><span class="o">.</span><span class="n">valset</span> <span class="o">=</span> <span class="n">datasets</span><span class="o">.</span><span class="n">ImageFolder</span><span class="p">(</span><span class="n">valdir</span><span class="p">,</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span>
  475. <span class="n">transforms</span><span class="o">.</span><span class="n">Resize</span><span class="p">(</span><span class="n">resize_size</span><span class="p">),</span>
  476. <span class="n">transforms</span><span class="o">.</span><span class="n">CenterCrop</span><span class="p">(</span><span class="n">crop_size</span><span class="p">),</span>
  477. <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span>
  478. <span class="n">normalize</span><span class="p">,</span>
  479. <span class="p">]))</span></div>
  480. <div class="viewcode-block" id="ClassificationDatasetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.ClassificationDatasetInterface">[docs]</a><span class="k">class</span> <span class="nc">ClassificationDatasetInterface</span><span class="p">(</span><span class="n">DatasetInterface</span><span class="p">):</span>
  481. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">normalization_mean</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="n">normalization_std</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">resolution</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span>
  482. <span class="n">dataset_params</span><span class="o">=</span><span class="p">{}):</span>
  483. <span class="nb">super</span><span class="p">(</span><span class="n">ClassificationDatasetInterface</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">)</span>
  484. <span class="n">data_dir</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">dataset_dir</span>
  485. <span class="n">traindir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">abspath</span><span class="p">(</span><span class="n">data_dir</span><span class="p">),</span> <span class="s1">&#39;train&#39;</span><span class="p">)</span>
  486. <span class="n">valdir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">data_dir</span><span class="p">,</span> <span class="s1">&#39;val&#39;</span><span class="p">)</span>
  487. <span class="n">normalize</span> <span class="o">=</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Normalize</span><span class="p">(</span><span class="n">mean</span><span class="o">=</span><span class="n">normalization_mean</span><span class="p">,</span>
  488. <span class="n">std</span><span class="o">=</span><span class="n">normalization_std</span><span class="p">)</span>
  489. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span> <span class="o">=</span> <span class="n">datasets</span><span class="o">.</span><span class="n">ImageFolder</span><span class="p">(</span>
  490. <span class="n">traindir</span><span class="p">,</span>
  491. <span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span>
  492. <span class="n">transforms</span><span class="o">.</span><span class="n">RandomResizedCrop</span><span class="p">(</span><span class="n">resolution</span><span class="p">),</span>
  493. <span class="n">transforms</span><span class="o">.</span><span class="n">RandomHorizontalFlip</span><span class="p">(),</span>
  494. <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span>
  495. <span class="n">normalize</span><span class="p">,</span>
  496. <span class="p">]))</span>
  497. <span class="bp">self</span><span class="o">.</span><span class="n">valset</span> <span class="o">=</span> <span class="n">datasets</span><span class="o">.</span><span class="n">ImageFolder</span><span class="p">(</span><span class="n">valdir</span><span class="p">,</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span>
  498. <span class="n">transforms</span><span class="o">.</span><span class="n">Resize</span><span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">resolution</span> <span class="o">*</span> <span class="mf">1.15</span><span class="p">)),</span>
  499. <span class="n">transforms</span><span class="o">.</span><span class="n">CenterCrop</span><span class="p">(</span><span class="n">resolution</span><span class="p">),</span>
  500. <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span>
  501. <span class="n">normalize</span><span class="p">,</span>
  502. <span class="p">]))</span>
  503. <span class="bp">self</span><span class="o">.</span><span class="n">data_dir</span> <span class="o">=</span> <span class="n">data_dir</span>
  504. <span class="bp">self</span><span class="o">.</span><span class="n">normalization_mean</span> <span class="o">=</span> <span class="n">normalization_mean</span>
  505. <span class="bp">self</span><span class="o">.</span><span class="n">normalization_std</span> <span class="o">=</span> <span class="n">normalization_std</span></div>
  506. <div class="viewcode-block" id="PascalVOC2012SegmentationDataSetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.PascalVOC2012SegmentationDataSetInterface">[docs]</a><span class="k">class</span> <span class="nc">PascalVOC2012SegmentationDataSetInterface</span><span class="p">(</span><span class="n">DatasetInterface</span><span class="p">):</span>
  507. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">cache_labels</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">cache_images</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
  508. <span class="k">if</span> <span class="n">dataset_params</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
  509. <span class="n">dataset_params</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
  510. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">)</span>
  511. <span class="bp">self</span><span class="o">.</span><span class="n">root_dir</span> <span class="o">=</span> <span class="n">dataset_params</span><span class="p">[</span><span class="s1">&#39;dataset_dir&#39;</span><span class="p">]</span> <span class="k">if</span> <span class="s1">&#39;dataset_dir&#39;</span> <span class="ow">in</span> <span class="n">dataset_params</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span> \
  512. <span class="k">else</span> <span class="s1">&#39;/data/pascal_voc_2012/VOCdevkit/VOC2012/&#39;</span>
  513. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span> <span class="o">=</span> <span class="n">PascalVOC2012SegmentationDataSet</span><span class="p">(</span><span class="n">root</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">root_dir</span><span class="p">,</span>
  514. <span class="n">list_file</span><span class="o">=</span><span class="s1">&#39;ImageSets/Segmentation/train.txt&#39;</span><span class="p">,</span>
  515. <span class="n">samples_sub_directory</span><span class="o">=</span><span class="s1">&#39;JPEGImages&#39;</span><span class="p">,</span>
  516. <span class="n">targets_sub_directory</span><span class="o">=</span><span class="s1">&#39;SegmentationClass&#39;</span><span class="p">,</span> <span class="n">augment</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
  517. <span class="n">dataset_hyper_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span> <span class="n">cache_labels</span><span class="o">=</span><span class="n">cache_labels</span><span class="p">,</span>
  518. <span class="n">cache_images</span><span class="o">=</span><span class="n">cache_images</span><span class="p">)</span>
  519. <span class="bp">self</span><span class="o">.</span><span class="n">valset</span> <span class="o">=</span> <span class="n">PascalVOC2012SegmentationDataSet</span><span class="p">(</span><span class="n">root</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">root_dir</span><span class="p">,</span>
  520. <span class="n">list_file</span><span class="o">=</span><span class="s1">&#39;ImageSets/Segmentation/val.txt&#39;</span><span class="p">,</span>
  521. <span class="n">samples_sub_directory</span><span class="o">=</span><span class="s1">&#39;JPEGImages&#39;</span><span class="p">,</span>
  522. <span class="n">targets_sub_directory</span><span class="o">=</span><span class="s1">&#39;SegmentationClass&#39;</span><span class="p">,</span> <span class="n">augment</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
  523. <span class="n">dataset_hyper_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span> <span class="n">cache_labels</span><span class="o">=</span><span class="n">cache_labels</span><span class="p">,</span>
  524. <span class="n">cache_images</span><span class="o">=</span><span class="n">cache_images</span><span class="p">)</span>
  525. <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="o">.</span><span class="n">classes</span></div>
  526. <div class="viewcode-block" id="PascalAUG2012SegmentationDataSetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.PascalAUG2012SegmentationDataSetInterface">[docs]</a><span class="k">class</span> <span class="nc">PascalAUG2012SegmentationDataSetInterface</span><span class="p">(</span><span class="n">DatasetInterface</span><span class="p">):</span>
  527. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">cache_labels</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">cache_images</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
  528. <span class="k">if</span> <span class="n">dataset_params</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
  529. <span class="n">dataset_params</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
  530. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">)</span>
  531. <span class="bp">self</span><span class="o">.</span><span class="n">root_dir</span> <span class="o">=</span> <span class="n">dataset_params</span><span class="p">[</span><span class="s1">&#39;dataset_dir&#39;</span><span class="p">]</span> <span class="k">if</span> <span class="s1">&#39;dataset_dir&#39;</span> <span class="ow">in</span> <span class="n">dataset_params</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span> \
  532. <span class="k">else</span> <span class="s1">&#39;/data/pascal_voc_2012/VOCaug/dataset/&#39;</span>
  533. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span> <span class="o">=</span> <span class="n">PascalAUG2012SegmentationDataSet</span><span class="p">(</span>
  534. <span class="n">root</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">root_dir</span><span class="p">,</span>
  535. <span class="n">list_file</span><span class="o">=</span><span class="s1">&#39;trainval.txt&#39;</span><span class="p">,</span>
  536. <span class="n">samples_sub_directory</span><span class="o">=</span><span class="s1">&#39;img&#39;</span><span class="p">,</span>
  537. <span class="n">targets_sub_directory</span><span class="o">=</span><span class="s1">&#39;cls&#39;</span><span class="p">,</span> <span class="n">augment</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
  538. <span class="n">dataset_hyper_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span> <span class="n">cache_labels</span><span class="o">=</span><span class="n">cache_labels</span><span class="p">,</span>
  539. <span class="n">cache_images</span><span class="o">=</span><span class="n">cache_images</span><span class="p">)</span>
  540. <span class="bp">self</span><span class="o">.</span><span class="n">valset</span> <span class="o">=</span> <span class="n">PascalAUG2012SegmentationDataSet</span><span class="p">(</span>
  541. <span class="n">root</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">root_dir</span><span class="p">,</span>
  542. <span class="n">list_file</span><span class="o">=</span><span class="s1">&#39;val.txt&#39;</span><span class="p">,</span>
  543. <span class="n">samples_sub_directory</span><span class="o">=</span><span class="s1">&#39;img&#39;</span><span class="p">,</span>
  544. <span class="n">targets_sub_directory</span><span class="o">=</span><span class="s1">&#39;cls&#39;</span><span class="p">,</span> <span class="n">augment</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
  545. <span class="n">dataset_hyper_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span> <span class="n">cache_labels</span><span class="o">=</span><span class="n">cache_labels</span><span class="p">,</span>
  546. <span class="n">cache_images</span><span class="o">=</span><span class="n">cache_images</span><span class="p">)</span>
  547. <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="o">.</span><span class="n">classes</span></div>
  548. <div class="viewcode-block" id="CoCoDataSetInterfaceBase"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.CoCoDataSetInterfaceBase">[docs]</a><span class="k">class</span> <span class="nc">CoCoDataSetInterfaceBase</span><span class="p">(</span><span class="n">DatasetInterface</span><span class="p">):</span>
  549. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  550. <span class="k">if</span> <span class="n">dataset_params</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
  551. <span class="n">dataset_params</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
  552. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">)</span>
  553. <span class="bp">self</span><span class="o">.</span><span class="n">root_dir</span> <span class="o">=</span> <span class="n">dataset_params</span><span class="p">[</span><span class="s1">&#39;dataset_dir&#39;</span><span class="p">]</span> <span class="k">if</span> <span class="s1">&#39;dataset_dir&#39;</span> <span class="ow">in</span> <span class="n">dataset_params</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span> <span class="k">else</span> <span class="s1">&#39;/data/coco/&#39;</span></div>
  554. <div class="viewcode-block" id="CoCoSegmentationDatasetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.CoCoSegmentationDatasetInterface">[docs]</a><span class="k">class</span> <span class="nc">CoCoSegmentationDatasetInterface</span><span class="p">(</span><span class="n">CoCoDataSetInterfaceBase</span><span class="p">):</span>
  555. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">cache_labels</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> <span class="n">cache_images</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
  556. <span class="n">dataset_classes_inclusion_tuples_list</span><span class="p">:</span> <span class="nb">list</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  557. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">)</span>
  558. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span> <span class="o">=</span> <span class="n">CoCoSegmentationDataSet</span><span class="p">(</span>
  559. <span class="n">root</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">root_dir</span><span class="p">,</span>
  560. <span class="n">list_file</span><span class="o">=</span><span class="s1">&#39;instances_train2017.json&#39;</span><span class="p">,</span>
  561. <span class="n">samples_sub_directory</span><span class="o">=</span><span class="s1">&#39;images/train2017&#39;</span><span class="p">,</span>
  562. <span class="n">targets_sub_directory</span><span class="o">=</span><span class="s1">&#39;annotations&#39;</span><span class="p">,</span> <span class="n">augment</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
  563. <span class="n">dataset_hyper_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
  564. <span class="n">cache_labels</span><span class="o">=</span><span class="n">cache_labels</span><span class="p">,</span>
  565. <span class="n">cache_images</span><span class="o">=</span><span class="n">cache_images</span><span class="p">,</span>
  566. <span class="n">dataset_classes_inclusion_tuples_list</span><span class="o">=</span><span class="n">dataset_classes_inclusion_tuples_list</span><span class="p">)</span>
  567. <span class="bp">self</span><span class="o">.</span><span class="n">valset</span> <span class="o">=</span> <span class="n">CoCoSegmentationDataSet</span><span class="p">(</span>
  568. <span class="n">root</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">root_dir</span><span class="p">,</span>
  569. <span class="n">list_file</span><span class="o">=</span><span class="s1">&#39;instances_val2017.json&#39;</span><span class="p">,</span>
  570. <span class="n">samples_sub_directory</span><span class="o">=</span><span class="s1">&#39;images/val2017&#39;</span><span class="p">,</span>
  571. <span class="n">targets_sub_directory</span><span class="o">=</span><span class="s1">&#39;annotations&#39;</span><span class="p">,</span> <span class="n">augment</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
  572. <span class="n">dataset_hyper_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
  573. <span class="n">cache_labels</span><span class="o">=</span><span class="n">cache_labels</span><span class="p">,</span>
  574. <span class="n">cache_images</span><span class="o">=</span><span class="n">cache_images</span><span class="p">,</span>
  575. <span class="n">dataset_classes_inclusion_tuples_list</span><span class="o">=</span><span class="n">dataset_classes_inclusion_tuples_list</span><span class="p">)</span>
  576. <span class="bp">self</span><span class="o">.</span><span class="n">coco_classes</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="o">.</span><span class="n">classes</span></div>
  577. <div class="viewcode-block" id="CityscapesDatasetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.CityscapesDatasetInterface">[docs]</a><span class="k">class</span> <span class="nc">CityscapesDatasetInterface</span><span class="p">(</span><span class="n">DatasetInterface</span><span class="p">):</span>
  578. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">cache_labels</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> <span class="n">cache_images</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
  579. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">)</span>
  580. <span class="n">root_dir</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s2">&quot;dataset_dir&quot;</span><span class="p">,</span> <span class="s2">&quot;/data/cityscapes&quot;</span><span class="p">)</span>
  581. <span class="n">img_size</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s2">&quot;img_size&quot;</span><span class="p">,</span> <span class="mi">1024</span><span class="p">)</span>
  582. <span class="n">crop_size</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s2">&quot;crop_size&quot;</span><span class="p">,</span> <span class="mi">512</span><span class="p">)</span>
  583. <span class="n">image_mask_transforms</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s2">&quot;image_mask_transforms&quot;</span><span class="p">)</span>
  584. <span class="n">image_mask_transforms_aug</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s2">&quot;image_mask_transforms_aug&quot;</span><span class="p">)</span>
  585. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span> <span class="o">=</span> <span class="n">CityscapesDataset</span><span class="p">(</span>
  586. <span class="n">root_dir</span><span class="o">=</span><span class="n">root_dir</span><span class="p">,</span>
  587. <span class="n">list_file</span><span class="o">=</span><span class="s1">&#39;lists/train.lst&#39;</span><span class="p">,</span>
  588. <span class="n">labels_csv_path</span><span class="o">=</span><span class="s2">&quot;lists/labels.csv&quot;</span><span class="p">,</span>
  589. <span class="n">img_size</span><span class="o">=</span><span class="n">img_size</span><span class="p">,</span>
  590. <span class="n">crop_size</span><span class="o">=</span><span class="n">crop_size</span><span class="p">,</span>
  591. <span class="n">augment</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
  592. <span class="n">dataset_hyper_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
  593. <span class="n">cache_labels</span><span class="o">=</span><span class="n">cache_labels</span><span class="p">,</span>
  594. <span class="n">cache_images</span><span class="o">=</span><span class="n">cache_images</span><span class="p">,</span>
  595. <span class="n">image_mask_transforms</span><span class="o">=</span><span class="n">image_mask_transforms</span><span class="p">,</span>
  596. <span class="n">image_mask_transforms_aug</span><span class="o">=</span><span class="n">image_mask_transforms_aug</span><span class="p">)</span>
  597. <span class="bp">self</span><span class="o">.</span><span class="n">valset</span> <span class="o">=</span> <span class="n">CityscapesDataset</span><span class="p">(</span>
  598. <span class="n">root_dir</span><span class="o">=</span><span class="n">root_dir</span><span class="p">,</span>
  599. <span class="n">list_file</span><span class="o">=</span><span class="s1">&#39;lists/val.lst&#39;</span><span class="p">,</span>
  600. <span class="n">labels_csv_path</span><span class="o">=</span><span class="s2">&quot;lists/labels.csv&quot;</span><span class="p">,</span>
  601. <span class="n">img_size</span><span class="o">=</span><span class="n">img_size</span><span class="p">,</span>
  602. <span class="n">crop_size</span><span class="o">=</span><span class="n">crop_size</span><span class="p">,</span>
  603. <span class="n">augment</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
  604. <span class="n">dataset_hyper_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
  605. <span class="n">cache_labels</span><span class="o">=</span><span class="n">cache_labels</span><span class="p">,</span>
  606. <span class="n">cache_images</span><span class="o">=</span><span class="n">cache_images</span><span class="p">,</span>
  607. <span class="n">image_mask_transforms</span><span class="o">=</span><span class="n">image_mask_transforms</span><span class="p">)</span>
  608. <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="o">.</span><span class="n">classes</span></div>
  609. <div class="viewcode-block" id="SuperviselyPersonsDatasetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.SuperviselyPersonsDatasetInterface">[docs]</a><span class="k">class</span> <span class="nc">SuperviselyPersonsDatasetInterface</span><span class="p">(</span><span class="n">DatasetInterface</span><span class="p">):</span>
  610. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">cache_labels</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> <span class="n">cache_images</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
  611. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">)</span>
  612. <span class="n">root_dir</span> <span class="o">=</span> <span class="n">get_param</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s2">&quot;dataset_dir&quot;</span><span class="p">,</span> <span class="s2">&quot;/data/supervisely-persons&quot;</span><span class="p">)</span>
  613. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span> <span class="o">=</span> <span class="n">SuperviselyPersonsDataset</span><span class="p">(</span>
  614. <span class="n">root_dir</span><span class="o">=</span><span class="n">root_dir</span><span class="p">,</span>
  615. <span class="n">list_file</span><span class="o">=</span><span class="s1">&#39;train.csv&#39;</span><span class="p">,</span>
  616. <span class="n">dataset_hyper_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
  617. <span class="n">cache_labels</span><span class="o">=</span><span class="n">cache_labels</span><span class="p">,</span>
  618. <span class="n">cache_images</span><span class="o">=</span><span class="n">cache_images</span><span class="p">,</span>
  619. <span class="n">image_mask_transforms_aug</span><span class="o">=</span><span class="n">get_param</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s2">&quot;image_mask_transforms_aug&quot;</span><span class="p">,</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([])),</span>
  620. <span class="n">augment</span><span class="o">=</span><span class="kc">True</span>
  621. <span class="p">)</span>
  622. <span class="bp">self</span><span class="o">.</span><span class="n">valset</span> <span class="o">=</span> <span class="n">SuperviselyPersonsDataset</span><span class="p">(</span>
  623. <span class="n">root_dir</span><span class="o">=</span><span class="n">root_dir</span><span class="p">,</span>
  624. <span class="n">list_file</span><span class="o">=</span><span class="s1">&#39;val.csv&#39;</span><span class="p">,</span>
  625. <span class="n">dataset_hyper_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">,</span>
  626. <span class="n">cache_labels</span><span class="o">=</span><span class="n">cache_labels</span><span class="p">,</span>
  627. <span class="n">cache_images</span><span class="o">=</span><span class="n">cache_images</span><span class="p">,</span>
  628. <span class="n">image_mask_transforms</span><span class="o">=</span><span class="n">get_param</span><span class="p">(</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s2">&quot;image_mask_transforms&quot;</span><span class="p">,</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([])),</span>
  629. <span class="n">augment</span><span class="o">=</span><span class="kc">False</span>
  630. <span class="p">)</span>
  631. <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="o">.</span><span class="n">classes</span></div>
  632. <div class="viewcode-block" id="DetectionDatasetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.DetectionDatasetInterface">[docs]</a><span class="k">class</span> <span class="nc">DetectionDatasetInterface</span><span class="p">(</span><span class="n">DatasetInterface</span><span class="p">):</span>
  633. <div class="viewcode-block" id="DetectionDatasetInterface.build_data_loaders"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.DetectionDatasetInterface.build_data_loaders">[docs]</a> <span class="k">def</span> <span class="nf">build_data_loaders</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size_factor</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">num_workers</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">train_batch_size</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">val_batch_size</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
  634. <span class="n">test_batch_size</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">distributed_sampler</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
  635. <span class="n">train_sampler</span> <span class="o">=</span> <span class="n">InfiniteSampler</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="p">),</span> <span class="n">seed</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
  636. <span class="n">train_batch_sampler</span> <span class="o">=</span> <span class="n">BatchSampler</span><span class="p">(</span>
  637. <span class="n">sampler</span><span class="o">=</span><span class="n">train_sampler</span><span class="p">,</span>
  638. <span class="n">batch_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span>
  639. <span class="n">drop_last</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
  640. <span class="p">)</span>
  641. <span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="p">,</span>
  642. <span class="n">batch_sampler</span><span class="o">=</span><span class="n">train_batch_sampler</span><span class="p">,</span>
  643. <span class="n">num_workers</span><span class="o">=</span><span class="n">num_workers</span><span class="p">,</span>
  644. <span class="n">pin_memory</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
  645. <span class="n">worker_init_fn</span><span class="o">=</span><span class="n">worker_init_reset_seed</span><span class="p">,</span>
  646. <span class="n">collate_fn</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">train_collate_fn</span><span class="p">)</span>
  647. <span class="k">if</span> <span class="n">distributed_sampler</span><span class="p">:</span>
  648. <span class="n">sampler</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">distributed</span><span class="o">.</span><span class="n">DistributedSampler</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">valset</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
  649. <span class="k">else</span><span class="p">:</span>
  650. <span class="n">sampler</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">SequentialSampler</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">valset</span><span class="p">)</span>
  651. <span class="n">val_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">valset</span><span class="p">,</span>
  652. <span class="n">num_workers</span><span class="o">=</span><span class="n">num_workers</span><span class="p">,</span>
  653. <span class="n">pin_memory</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
  654. <span class="n">sampler</span><span class="o">=</span><span class="n">sampler</span><span class="p">,</span>
  655. <span class="n">batch_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">val_batch_size</span><span class="p">,</span>
  656. <span class="n">collate_fn</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">val_collate_fn</span><span class="p">)</span>
  657. <span class="bp">self</span><span class="o">.</span><span class="n">val_loader</span> <span class="o">=</span> <span class="n">val_loader</span></div></div>
  658. <div class="viewcode-block" id="PascalVOCUnifiedDetectionDatasetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.PascalVOCUnifiedDetectionDatasetInterface">[docs]</a><span class="k">class</span> <span class="nc">PascalVOCUnifiedDetectionDatasetInterface</span><span class="p">(</span><span class="n">DetectionDatasetInterface</span><span class="p">):</span>
  659. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  660. <span class="k">if</span> <span class="n">dataset_params</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
  661. <span class="n">dataset_params</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
  662. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">)</span>
  663. <span class="bp">self</span><span class="o">.</span><span class="n">data_dir</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">data_dir</span>
  664. <span class="n">train_input_dim</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">train_image_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">train_image_size</span><span class="p">)</span>
  665. <span class="n">val_input_dim</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">val_image_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">val_image_size</span><span class="p">)</span>
  666. <span class="n">train_max_num_samples</span> <span class="o">=</span> <span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s2">&quot;train_max_num_samples&quot;</span><span class="p">)</span>
  667. <span class="n">val_max_num_samples</span> <span class="o">=</span> <span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s2">&quot;val_max_num_samples&quot;</span><span class="p">)</span>
  668. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">download</span><span class="p">:</span>
  669. <span class="n">PascalVOCDetectionDataset</span><span class="o">.</span><span class="n">download</span><span class="p">(</span><span class="n">data_dir</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">data_dir</span><span class="p">)</span>
  670. <span class="n">train_dataset_names</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;train2007&quot;</span><span class="p">,</span> <span class="s2">&quot;val2007&quot;</span><span class="p">,</span> <span class="s2">&quot;train2012&quot;</span><span class="p">,</span> <span class="s2">&quot;val2012&quot;</span><span class="p">]</span>
  671. <span class="c1"># We divide train_max_num_samples between the datasets</span>
  672. <span class="k">if</span> <span class="n">train_max_num_samples</span><span class="p">:</span>
  673. <span class="n">max_num_samples_per_train_dataset</span> <span class="o">=</span> <span class="p">[</span><span class="nb">len</span><span class="p">(</span><span class="n">segment</span><span class="p">)</span> <span class="k">for</span> <span class="n">segment</span> <span class="ow">in</span> <span class="n">np</span><span class="o">.</span><span class="n">array_split</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">train_max_num_samples</span><span class="p">),</span> <span class="nb">len</span><span class="p">(</span><span class="n">train_dataset_names</span><span class="p">))]</span>
  674. <span class="k">else</span><span class="p">:</span>
  675. <span class="n">max_num_samples_per_train_dataset</span> <span class="o">=</span> <span class="p">[</span><span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">train_dataset_names</span><span class="p">)</span>
  676. <span class="n">train_sets</span> <span class="o">=</span> <span class="p">[</span><span class="n">PascalVOCDetectionDataset</span><span class="p">(</span><span class="n">data_dir</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">data_dir</span><span class="p">,</span>
  677. <span class="n">input_dim</span><span class="o">=</span><span class="n">train_input_dim</span><span class="p">,</span>
  678. <span class="n">cache</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">cache_train_images</span><span class="p">,</span>
  679. <span class="n">cache_path</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">cache_dir</span> <span class="o">+</span> <span class="s2">&quot;cache_train&quot;</span><span class="p">,</span>
  680. <span class="n">transforms</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">train_transforms</span><span class="p">,</span>
  681. <span class="n">images_sub_directory</span><span class="o">=</span><span class="s1">&#39;images/&#39;</span> <span class="o">+</span> <span class="n">trainset_name</span> <span class="o">+</span> <span class="s1">&#39;/&#39;</span><span class="p">,</span>
  682. <span class="n">class_inclusion_list</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">class_inclusion_list</span><span class="p">,</span>
  683. <span class="n">max_num_samples</span><span class="o">=</span><span class="n">max_num_samples_per_train_dataset</span><span class="p">[</span><span class="n">i</span><span class="p">])</span>
  684. <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">trainset_name</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">train_dataset_names</span><span class="p">)]</span>
  685. <span class="n">testset2007</span> <span class="o">=</span> <span class="n">PascalVOCDetectionDataset</span><span class="p">(</span><span class="n">data_dir</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">data_dir</span><span class="p">,</span>
  686. <span class="n">input_dim</span><span class="o">=</span><span class="n">val_input_dim</span><span class="p">,</span>
  687. <span class="n">cache</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">cache_val_images</span><span class="p">,</span>
  688. <span class="n">cache_path</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">cache_dir</span> <span class="o">+</span> <span class="s2">&quot;cache_valid&quot;</span><span class="p">,</span>
  689. <span class="n">transforms</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">val_transforms</span><span class="p">,</span>
  690. <span class="n">images_sub_directory</span><span class="o">=</span><span class="s1">&#39;images/test2007/&#39;</span><span class="p">,</span>
  691. <span class="n">class_inclusion_list</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">class_inclusion_list</span><span class="p">,</span>
  692. <span class="n">max_num_samples</span><span class="o">=</span><span class="n">val_max_num_samples</span><span class="p">)</span>
  693. <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="n">train_sets</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">classes</span>
  694. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span> <span class="o">=</span> <span class="n">ConcatDataset</span><span class="p">(</span><span class="n">train_sets</span><span class="p">)</span>
  695. <span class="bp">self</span><span class="o">.</span><span class="n">valset</span> <span class="o">=</span> <span class="n">testset2007</span>
  696. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="o">.</span><span class="n">collate_fn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">train_collate_fn</span>
  697. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">classes</span>
  698. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="o">.</span><span class="n">img_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">train_image_size</span>
  699. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span><span class="o">.</span><span class="n">cache_labels</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">cache_train_images</span></div>
  700. <div class="viewcode-block" id="CoCoDetectionDatasetInterface"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.dataset_interfaces.html#super_gradients.training.CoCoDetectionDatasetInterface">[docs]</a><span class="k">class</span> <span class="nc">CoCoDetectionDatasetInterface</span><span class="p">(</span><span class="n">DetectionDatasetInterface</span><span class="p">):</span>
  701. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="p">{}):</span>
  702. <span class="nb">super</span><span class="p">(</span><span class="n">CoCoDetectionDatasetInterface</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dataset_params</span><span class="o">=</span><span class="n">dataset_params</span><span class="p">)</span>
  703. <span class="n">train_input_dim</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">train_image_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">train_image_size</span><span class="p">)</span>
  704. <span class="n">targets_format</span> <span class="o">=</span> <span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s2">&quot;targets_format&quot;</span><span class="p">,</span> <span class="n">DetectionTargetsFormat</span><span class="o">.</span><span class="n">LABEL_CXCYWH</span><span class="p">)</span>
  705. <span class="n">train_transforms</span> <span class="o">=</span> <span class="p">[</span><span class="n">DetectionMosaic</span><span class="p">(</span><span class="n">input_dim</span><span class="o">=</span><span class="n">train_input_dim</span><span class="p">,</span>
  706. <span class="n">prob</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">mosaic_prob</span><span class="p">),</span>
  707. <span class="n">DetectionRandomAffine</span><span class="p">(</span><span class="n">degrees</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">degrees</span><span class="p">,</span>
  708. <span class="n">translate</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">translate</span><span class="p">,</span>
  709. <span class="n">scales</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">mosaic_scale</span><span class="p">,</span>
  710. <span class="n">shear</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">shear</span><span class="p">,</span>
  711. <span class="n">target_size</span><span class="o">=</span><span class="n">train_input_dim</span><span class="p">,</span>
  712. <span class="n">filter_box_candidates</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">filter_box_candidates</span><span class="p">,</span>
  713. <span class="n">wh_thr</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">wh_thr</span><span class="p">,</span>
  714. <span class="n">area_thr</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">area_thr</span><span class="p">,</span>
  715. <span class="n">ar_thr</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">ar_thr</span>
  716. <span class="p">),</span>
  717. <span class="n">DetectionMixup</span><span class="p">(</span><span class="n">input_dim</span><span class="o">=</span><span class="n">train_input_dim</span><span class="p">,</span>
  718. <span class="n">mixup_scale</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">mixup_scale</span><span class="p">,</span>
  719. <span class="n">prob</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">mixup_prob</span><span class="p">,</span>
  720. <span class="n">flip_prob</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">flip_prob</span><span class="p">),</span>
  721. <span class="n">DetectionHSV</span><span class="p">(</span><span class="n">prob</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">hsv_prob</span><span class="p">,</span>
  722. <span class="n">hgain</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">hgain</span><span class="p">,</span>
  723. <span class="n">sgain</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">sgain</span><span class="p">,</span>
  724. <span class="n">vgain</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">vgain</span>
  725. <span class="p">),</span>
  726. <span class="n">DetectionHorizontalFlip</span><span class="p">(</span><span class="n">prob</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">flip_prob</span><span class="p">),</span>
  727. <span class="n">DetectionPaddedRescale</span><span class="p">(</span><span class="n">input_dim</span><span class="o">=</span><span class="n">train_input_dim</span><span class="p">,</span> <span class="n">max_targets</span><span class="o">=</span><span class="mi">120</span><span class="p">),</span>
  728. <span class="n">DetectionTargetsFormatTransform</span><span class="p">(</span><span class="n">output_format</span><span class="o">=</span><span class="n">targets_format</span><span class="p">)</span>
  729. <span class="p">]</span>
  730. <span class="c1"># IF CACHE- CREATING THE CACHE FILE WILL HAPPEN ONLY FOR RANK 0, THEN ALL THE OTHER RANKS SIMPLY READ FROM IT.</span>
  731. <span class="n">local_rank</span> <span class="o">=</span> <span class="n">get_local_rank</span><span class="p">()</span>
  732. <span class="k">with</span> <span class="n">wait_for_the_master</span><span class="p">(</span><span class="n">local_rank</span><span class="p">):</span>
  733. <span class="bp">self</span><span class="o">.</span><span class="n">trainset</span> <span class="o">=</span> <span class="n">COCODetectionDataset</span><span class="p">(</span><span class="n">data_dir</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">data_dir</span><span class="p">,</span>
  734. <span class="n">name</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">train_subdir</span><span class="p">,</span>
  735. <span class="n">json_file</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">train_json_file</span><span class="p">,</span>
  736. <span class="n">img_size</span><span class="o">=</span><span class="n">train_input_dim</span><span class="p">,</span>
  737. <span class="n">cache</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">cache_train_images</span><span class="p">,</span>
  738. <span class="n">cache_dir_path</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">cache_dir_path</span><span class="p">,</span>
  739. <span class="n">transforms</span><span class="o">=</span><span class="n">train_transforms</span><span class="p">,</span>
  740. <span class="n">with_crowd</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
  741. <span class="n">val_input_dim</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">val_image_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">val_image_size</span><span class="p">)</span>
  742. <span class="n">with_crowd</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="s1">&#39;with_crowd&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  743. <span class="c1"># IF CACHE- CREATING THE CACHE FILE WILL HAPPEN ONLY FOR RANK 0, THEN ALL THE OTHER RANKS SIMPLY READ FROM IT.</span>
  744. <span class="k">with</span> <span class="n">wait_for_the_master</span><span class="p">(</span><span class="n">local_rank</span><span class="p">):</span>
  745. <span class="bp">self</span><span class="o">.</span><span class="n">valset</span> <span class="o">=</span> <span class="n">COCODetectionDataset</span><span class="p">(</span>
  746. <span class="n">data_dir</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">data_dir</span><span class="p">,</span>
  747. <span class="n">json_file</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">val_json_file</span><span class="p">,</span>
  748. <span class="n">name</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">val_subdir</span><span class="p">,</span>
  749. <span class="n">img_size</span><span class="o">=</span><span class="n">val_input_dim</span><span class="p">,</span>
  750. <span class="n">transforms</span><span class="o">=</span><span class="p">[</span><span class="n">DetectionPaddedRescale</span><span class="p">(</span><span class="n">input_dim</span><span class="o">=</span><span class="n">val_input_dim</span><span class="p">),</span>
  751. <span class="n">DetectionTargetsFormatTransform</span><span class="p">(</span><span class="n">max_targets</span><span class="o">=</span><span class="mi">50</span><span class="p">,</span> <span class="n">output_format</span><span class="o">=</span><span class="n">targets_format</span><span class="p">)],</span>
  752. <span class="n">cache</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">cache_val_images</span><span class="p">,</span>
  753. <span class="n">cache_dir_path</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">cache_dir_path</span><span class="p">,</span>
  754. <span class="n">with_crowd</span><span class="o">=</span><span class="n">with_crowd</span><span class="p">)</span>
  755. <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="n">COCO_DETECTION_CLASSES_LIST</span></div>
  756. </pre></div>
  757. </div>
  758. </div>
  759. <footer>
  760. <hr/>
  761. <div role="contentinfo">
  762. <p>&#169; Copyright 2021, SuperGradients team.</p>
  763. </div>
  764. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  765. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  766. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  767. </footer>
  768. </div>
  769. </div>
  770. </section>
  771. </div>
  772. <script>
  773. jQuery(function () {
  774. SphinxRtdTheme.Navigation.enable(true);
  775. });
  776. </script>
  777. </body>
  778. </html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.training.datasets.datasets_utils &mdash; SuperGradients 1.0 documentation</title>
  7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
  10. <!--[if lt IE 9]>
  11. <script src="../../../../_static/js/html5shiv.min.js"></script>
  12. <![endif]-->
  13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
  14. <script src="../../../../_static/jquery.js"></script>
  15. <script src="../../../../_static/underscore.js"></script>
  16. <script src="../../../../_static/doctools.js"></script>
  17. <script src="../../../../_static/js/theme.js"></script>
  18. <link rel="index" title="Index" href="../../../../genindex.html" />
  19. <link rel="search" title="Search" href="../../../../search.html" />
  20. </head>
  21. <body class="wy-body-for-nav">
  22. <div class="wy-grid-for-nav">
  23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  24. <div class="wy-side-scroll">
  25. <div class="wy-side-nav-search" >
  26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
  27. </a>
  28. <div role="search">
  29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
  30. <input type="text" name="q" placeholder="Search docs" />
  31. <input type="hidden" name="check_keywords" value="yes" />
  32. <input type="hidden" name="area" value="default" />
  33. </form>
  34. </div>
  35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
  37. <ul>
  38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
  39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
  40. </ul>
  41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
  42. <ul>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
  45. </ul>
  46. <p class="caption"><span class="caption-text">User Guide</span></p>
  47. <ul>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
  57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
  58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
  59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
  60. </ul>
  61. </div>
  62. </div>
  63. </nav>
  64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  66. <a href="../../../../index.html">SuperGradients</a>
  67. </nav>
  68. <div class="wy-nav-content">
  69. <div class="rst-content">
  70. <div role="navigation" aria-label="Page navigation">
  71. <ul class="wy-breadcrumbs">
  72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
  74. <li>super_gradients.training.datasets.datasets_utils</li>
  75. <li class="wy-breadcrumbs-aside">
  76. </li>
  77. </ul>
  78. <hr/>
  79. </div>
  80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  81. <div itemprop="articleBody">
  82. <h1>Source code for super_gradients.training.datasets.datasets_utils</h1><div class="highlight"><pre>
  83. <span></span><span class="kn">import</span> <span class="nn">copy</span>
  84. <span class="kn">import</span> <span class="nn">os</span>
  85. <span class="kn">from</span> <span class="nn">abc</span> <span class="kn">import</span> <span class="n">ABC</span><span class="p">,</span> <span class="n">abstractmethod</span>
  86. <span class="kn">from</span> <span class="nn">multiprocessing</span> <span class="kn">import</span> <span class="n">Value</span><span class="p">,</span> <span class="n">Lock</span>
  87. <span class="kn">import</span> <span class="nn">random</span>
  88. <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span>
  89. <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
  90. <span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="nn">F</span>
  91. <span class="kn">import</span> <span class="nn">torchvision</span>
  92. <span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span>
  93. <span class="kn">import</span> <span class="nn">torch</span>
  94. <span class="kn">import</span> <span class="nn">torch.distributed</span> <span class="k">as</span> <span class="nn">dist</span>
  95. <span class="kn">from</span> <span class="nn">super_gradients.common.sg_loggers.abstract_sg_logger</span> <span class="kn">import</span> <span class="n">AbstractSGLogger</span>
  96. <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
  97. <span class="kn">from</span> <span class="nn">deprecated</span> <span class="kn">import</span> <span class="n">deprecated</span>
  98. <span class="kn">from</span> <span class="nn">matplotlib.patches</span> <span class="kn">import</span> <span class="n">Rectangle</span>
  99. <span class="kn">from</span> <span class="nn">torchvision.datasets</span> <span class="kn">import</span> <span class="n">ImageFolder</span>
  100. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.auto_augment</span> <span class="kn">import</span> <span class="n">rand_augment_transform</span>
  101. <span class="kn">from</span> <span class="nn">torchvision.transforms</span> <span class="kn">import</span> <span class="n">transforms</span><span class="p">,</span> <span class="n">InterpolationMode</span><span class="p">,</span> <span class="n">RandomResizedCrop</span>
  102. <span class="kn">from</span> <span class="nn">tqdm</span> <span class="kn">import</span> <span class="n">tqdm</span>
  103. <span class="kn">from</span> <span class="nn">super_gradients.training.utils.utils</span> <span class="kn">import</span> <span class="n">AverageMeter</span>
  104. <span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">DetectionVisualization</span><span class="p">,</span> <span class="n">Anchors</span>
  105. <span class="kn">import</span> <span class="nn">uuid</span>
  106. <span class="kn">from</span> <span class="nn">super_gradients.training.utils.distributed_training_utils</span> <span class="kn">import</span> <span class="n">get_local_rank</span><span class="p">,</span> <span class="n">get_world_size</span>
  107. <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>
  108. <div class="viewcode-block" id="get_mean_and_std_torch"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.datasets_utils.get_mean_and_std_torch">[docs]</a><span class="k">def</span> <span class="nf">get_mean_and_std_torch</span><span class="p">(</span><span class="n">data_dir</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dataloader</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">num_workers</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">RandomResizeSize</span><span class="o">=</span><span class="mi">224</span><span class="p">):</span>
  109. <span class="sd">&quot;&quot;&quot;</span>
  110. <span class="sd"> A function for getting the mean and std of large datasets using pytorch dataloader and gpu functionality.</span>
  111. <span class="sd"> :param data_dir: String, path to none-library dataset folder. For example &quot;/data/Imagenette&quot; or &quot;/data/TinyImagenet&quot;</span>
  112. <span class="sd"> :param dataloader: a torch DataLoader, as it would feed the data into the trainer (including transforms etc).</span>
  113. <span class="sd"> :param RandomResizeSize: Int, the size of the RandomResizeCrop as it appears in the DataInterface (for example, for Imagenet,</span>
  114. <span class="sd"> this value should be 224).</span>
  115. <span class="sd"> :return: 2 lists,mean and std, each one of len 3 (1 for each channel)</span>
  116. <span class="sd"> &quot;&quot;&quot;</span>
  117. <span class="k">assert</span> <span class="n">data_dir</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="n">dataloader</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">,</span> <span class="s1">&#39;Please provide either path to data folder or DataLoader, not both.&#39;</span>
  118. <span class="k">if</span> <span class="n">dataloader</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
  119. <span class="n">traindir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">abspath</span><span class="p">(</span><span class="n">data_dir</span><span class="p">),</span> <span class="s1">&#39;train&#39;</span><span class="p">)</span>
  120. <span class="n">trainset</span> <span class="o">=</span> <span class="n">ImageFolder</span><span class="p">(</span><span class="n">traindir</span><span class="p">,</span> <span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span><span class="n">transforms</span><span class="o">.</span><span class="n">RandomResizedCrop</span><span class="p">(</span><span class="n">RandomResizeSize</span><span class="p">),</span>
  121. <span class="n">transforms</span><span class="o">.</span><span class="n">RandomHorizontalFlip</span><span class="p">(),</span>
  122. <span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">()]))</span>
  123. <span class="n">dataloader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="n">trainset</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">num_workers</span><span class="o">=</span><span class="n">num_workers</span><span class="p">)</span>
  124. <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Calculating on </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">dataloader</span><span class="o">.</span><span class="n">dataset</span><span class="o">.</span><span class="n">targets</span><span class="p">)</span><span class="si">}</span><span class="s1"> Training Samples&#39;</span><span class="p">)</span>
  125. <span class="n">device</span> <span class="o">=</span> <span class="s1">&#39;cuda:0&#39;</span> <span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">()</span> <span class="k">else</span> <span class="s1">&#39;cpu&#39;</span>
  126. <span class="n">h</span><span class="p">,</span> <span class="n">w</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span>
  127. <span class="k">for</span> <span class="n">batch_idx</span><span class="p">,</span> <span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">dataloader</span><span class="p">):</span>
  128. <span class="n">inputs</span> <span class="o">=</span> <span class="n">inputs</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
  129. <span class="k">if</span> <span class="n">batch_idx</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
  130. <span class="n">h</span><span class="p">,</span> <span class="n">w</span> <span class="o">=</span> <span class="n">inputs</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="n">inputs</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span>
  131. <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Min: </span><span class="si">{</span><span class="n">inputs</span><span class="o">.</span><span class="n">min</span><span class="p">()</span><span class="si">}</span><span class="s1">, Max: </span><span class="si">{</span><span class="n">inputs</span><span class="o">.</span><span class="n">max</span><span class="p">()</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
  132. <span class="n">chsum</span> <span class="o">=</span> <span class="n">inputs</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  133. <span class="k">else</span><span class="p">:</span>
  134. <span class="n">chsum</span> <span class="o">+=</span> <span class="n">inputs</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  135. <span class="n">mean</span> <span class="o">=</span> <span class="n">chsum</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="n">trainset</span><span class="p">)</span> <span class="o">/</span> <span class="n">h</span> <span class="o">/</span> <span class="n">w</span>
  136. <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;mean: </span><span class="si">{</span><span class="n">mean</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
  137. <span class="n">chsum</span> <span class="o">=</span> <span class="kc">None</span>
  138. <span class="k">for</span> <span class="n">batch_idx</span><span class="p">,</span> <span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">dataloader</span><span class="p">):</span>
  139. <span class="n">inputs</span> <span class="o">=</span> <span class="n">inputs</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
  140. <span class="k">if</span> <span class="n">batch_idx</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
  141. <span class="n">chsum</span> <span class="o">=</span> <span class="p">(</span><span class="n">inputs</span> <span class="o">-</span> <span class="n">mean</span><span class="p">)</span><span class="o">.</span><span class="n">pow</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  142. <span class="k">else</span><span class="p">:</span>
  143. <span class="n">chsum</span> <span class="o">+=</span> <span class="p">(</span><span class="n">inputs</span> <span class="o">-</span> <span class="n">mean</span><span class="p">)</span><span class="o">.</span><span class="n">pow</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  144. <span class="n">std</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">chsum</span> <span class="o">/</span> <span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">trainset</span><span class="p">)</span> <span class="o">*</span> <span class="n">h</span> <span class="o">*</span> <span class="n">w</span> <span class="o">-</span> <span class="mi">1</span><span class="p">))</span>
  145. <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;std: </span><span class="si">{</span><span class="n">std</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
  146. <span class="k">return</span> <span class="n">mean</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="o">.</span><span class="n">tolist</span><span class="p">(),</span> <span class="n">std</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span></div>
  147. <div class="viewcode-block" id="get_mean_and_std"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.datasets_utils.get_mean_and_std">[docs]</a><span class="nd">@deprecated</span><span class="p">(</span><span class="n">reason</span><span class="o">=</span><span class="s1">&#39;Use get_mean_and_std_torch() instead. It is faster and more accurate&#39;</span><span class="p">)</span>
  148. <span class="k">def</span> <span class="nf">get_mean_and_std</span><span class="p">(</span><span class="n">dataset</span><span class="p">):</span>
  149. <span class="sd">&#39;&#39;&#39;Compute the mean and std value of dataset.&#39;&#39;&#39;</span>
  150. <span class="n">dataloader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">num_workers</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
  151. <span class="n">mean</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span>
  152. <span class="n">std</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span>
  153. <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;==&gt; Computing mean and std..&#39;</span><span class="p">)</span>
  154. <span class="n">j</span> <span class="o">=</span> <span class="mi">0</span>
  155. <span class="k">for</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span> <span class="ow">in</span> <span class="n">dataloader</span><span class="p">:</span>
  156. <span class="k">if</span> <span class="n">j</span> <span class="o">%</span> <span class="mi">10</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
  157. <span class="nb">print</span><span class="p">(</span><span class="n">j</span><span class="p">)</span>
  158. <span class="n">j</span> <span class="o">+=</span> <span class="mi">1</span>
  159. <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">3</span><span class="p">):</span>
  160. <span class="n">mean</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">+=</span> <span class="n">inputs</span><span class="p">[:,</span> <span class="n">i</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:]</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>
  161. <span class="n">std</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">+=</span> <span class="n">inputs</span><span class="p">[:,</span> <span class="n">i</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:]</span><span class="o">.</span><span class="n">std</span><span class="p">()</span>
  162. <span class="n">mean</span><span class="o">.</span><span class="n">div_</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">dataset</span><span class="p">))</span>
  163. <span class="n">std</span><span class="o">.</span><span class="n">div_</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">dataset</span><span class="p">))</span>
  164. <span class="k">return</span> <span class="n">mean</span><span class="p">,</span> <span class="n">std</span></div>
  165. <div class="viewcode-block" id="AbstractCollateFunction"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.datasets_utils.AbstractCollateFunction">[docs]</a><span class="k">class</span> <span class="nc">AbstractCollateFunction</span><span class="p">(</span><span class="n">ABC</span><span class="p">):</span>
  166. <span class="sd">&quot;&quot;&quot;</span>
  167. <span class="sd"> A collate function (for torch DataLoader)</span>
  168. <span class="sd"> &quot;&quot;&quot;</span>
  169. <span class="nd">@abstractmethod</span>
  170. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">):</span>
  171. <span class="k">pass</span></div>
  172. <div class="viewcode-block" id="ComposedCollateFunction"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.datasets_utils.ComposedCollateFunction">[docs]</a><span class="k">class</span> <span class="nc">ComposedCollateFunction</span><span class="p">(</span><span class="n">AbstractCollateFunction</span><span class="p">):</span>
  173. <span class="sd">&quot;&quot;&quot;</span>
  174. <span class="sd"> A function (for torch DataLoader) which executes a sequence of sub collate functions</span>
  175. <span class="sd"> &quot;&quot;&quot;</span>
  176. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">functions</span><span class="p">:</span> <span class="nb">list</span><span class="p">):</span>
  177. <span class="bp">self</span><span class="o">.</span><span class="n">functions</span> <span class="o">=</span> <span class="n">functions</span>
  178. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">):</span>
  179. <span class="k">for</span> <span class="n">f</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">functions</span><span class="p">:</span>
  180. <span class="n">batch</span> <span class="o">=</span> <span class="n">f</span><span class="p">(</span><span class="n">batch</span><span class="p">)</span>
  181. <span class="k">return</span> <span class="n">batch</span></div>
  182. <div class="viewcode-block" id="AtomicInteger"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.datasets_utils.AtomicInteger">[docs]</a><span class="k">class</span> <span class="nc">AtomicInteger</span><span class="p">:</span>
  183. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">):</span>
  184. <span class="bp">self</span><span class="o">.</span><span class="n">_value</span> <span class="o">=</span> <span class="n">Value</span><span class="p">(</span><span class="s1">&#39;i&#39;</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span>
  185. <span class="k">def</span> <span class="fm">__set__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">instance</span><span class="p">,</span> <span class="n">value</span><span class="p">):</span>
  186. <span class="bp">self</span><span class="o">.</span><span class="n">_value</span><span class="o">.</span><span class="n">value</span> <span class="o">=</span> <span class="n">value</span>
  187. <span class="k">def</span> <span class="fm">__get__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">instance</span><span class="p">,</span> <span class="n">owner</span><span class="p">):</span>
  188. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_value</span><span class="o">.</span><span class="n">value</span></div>
  189. <div class="viewcode-block" id="MultiScaleCollateFunction"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.datasets_utils.MultiScaleCollateFunction">[docs]</a><span class="k">class</span> <span class="nc">MultiScaleCollateFunction</span><span class="p">(</span><span class="n">AbstractCollateFunction</span><span class="p">):</span>
  190. <span class="sd">&quot;&quot;&quot;</span>
  191. <span class="sd"> a collate function to implement multi-scale data augmentation</span>
  192. <span class="sd"> according to https://arxiv.org/pdf/1612.08242.pdf</span>
  193. <span class="sd"> &quot;&quot;&quot;</span>
  194. <span class="n">_counter</span> <span class="o">=</span> <span class="n">AtomicInteger</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
  195. <span class="n">_current_size</span> <span class="o">=</span> <span class="n">AtomicInteger</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
  196. <span class="n">_lock</span> <span class="o">=</span> <span class="n">Lock</span><span class="p">()</span>
  197. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">target_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">min_image_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">max_image_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
  198. <span class="n">image_size_steps</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span><span class="p">,</span>
  199. <span class="n">change_frequency</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">10</span><span class="p">):</span>
  200. <span class="sd">&quot;&quot;&quot;</span>
  201. <span class="sd"> set parameters for the multi-scale collate function</span>
  202. <span class="sd"> the possible image sizes are in range [min_image_size, max_image_size] in steps of image_size_steps</span>
  203. <span class="sd"> a new size will be randomly selected every change_frequency calls to the collate_fn()</span>
  204. <span class="sd"> :param target_size: scales will be [0.66 * target_size, 1.5 * target_size]</span>
  205. <span class="sd"> :param min_image_size: the minimum size to scale down to (in pixels)</span>
  206. <span class="sd"> :param max_image_size: the maximum size to scale up to (in pixels)</span>
  207. <span class="sd"> :param image_size_steps: typically, the stride of the net, which defines the possible image</span>
  208. <span class="sd"> size multiplications</span>
  209. <span class="sd"> :param change_frequency:</span>
  210. <span class="sd"> &quot;&quot;&quot;</span>
  211. <span class="k">assert</span> <span class="n">target_size</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">or</span> <span class="p">(</span><span class="n">max_image_size</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">min_image_size</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">),</span> \
  212. <span class="s1">&#39;either target_size or min_image_size and max_image_size has to be set&#39;</span>
  213. <span class="k">assert</span> <span class="n">target_size</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="n">max_image_size</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">,</span> <span class="s1">&#39;target_size and max_image_size cannot be both defined&#39;</span>
  214. <span class="k">if</span> <span class="n">target_size</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  215. <span class="n">min_image_size</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="mf">0.66</span> <span class="o">*</span> <span class="n">target_size</span> <span class="o">-</span> <span class="p">((</span><span class="mf">0.66</span> <span class="o">*</span> <span class="n">target_size</span><span class="p">)</span> <span class="o">%</span> <span class="n">image_size_steps</span><span class="p">)</span> <span class="o">+</span> <span class="n">image_size_steps</span><span class="p">)</span>
  216. <span class="n">max_image_size</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="mf">1.5</span> <span class="o">*</span> <span class="n">target_size</span> <span class="o">-</span> <span class="p">((</span><span class="mf">1.5</span> <span class="o">*</span> <span class="n">target_size</span><span class="p">)</span> <span class="o">%</span> <span class="n">image_size_steps</span><span class="p">))</span>
  217. <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Using multi-scale </span><span class="si">%g</span><span class="s1"> - </span><span class="si">%g</span><span class="s1">&#39;</span> <span class="o">%</span> <span class="p">(</span><span class="n">min_image_size</span><span class="p">,</span> <span class="n">max_image_size</span><span class="p">))</span>
  218. <span class="bp">self</span><span class="o">.</span><span class="n">sizes</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">min_image_size</span><span class="p">,</span> <span class="n">max_image_size</span> <span class="o">+</span> <span class="n">image_size_steps</span><span class="p">,</span> <span class="n">image_size_steps</span><span class="p">)</span>
  219. <span class="bp">self</span><span class="o">.</span><span class="n">image_size_steps</span> <span class="o">=</span> <span class="n">image_size_steps</span>
  220. <span class="bp">self</span><span class="o">.</span><span class="n">frequency</span> <span class="o">=</span> <span class="n">change_frequency</span>
  221. <span class="bp">self</span><span class="o">.</span><span class="n">_current_size</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">sizes</span><span class="p">)</span>
  222. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">):</span>
  223. <span class="k">with</span> <span class="bp">self</span><span class="o">.</span><span class="n">_lock</span><span class="p">:</span>
  224. <span class="c1"># Important: this implementation was tailored for a specific input. it assumes the batch is a tuple where</span>
  225. <span class="c1"># the images are the first item</span>
  226. <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">),</span> <span class="s1">&#39;this collate function expects the input to be a tuple (images, labels)&#39;</span>
  227. <span class="n">images</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
  228. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_counter</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">frequency</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
  229. <span class="bp">self</span><span class="o">.</span><span class="n">_current_size</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">sizes</span><span class="p">)</span>
  230. <span class="bp">self</span><span class="o">.</span><span class="n">_counter</span> <span class="o">+=</span> <span class="mi">1</span>
  231. <span class="k">assert</span> <span class="n">images</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">image_size_steps</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">images</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">image_size_steps</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> \
  232. <span class="s1">&#39;images sized not divisible by </span><span class="si">%d</span><span class="s1">. (resize images before calling multi_scale)&#39;</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">image_size_steps</span>
  233. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_current_size</span> <span class="o">!=</span> <span class="nb">max</span><span class="p">(</span><span class="n">images</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">:]):</span>
  234. <span class="n">ratio</span> <span class="o">=</span> <span class="nb">float</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_current_size</span><span class="p">)</span> <span class="o">/</span> <span class="nb">max</span><span class="p">(</span><span class="n">images</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">:])</span>
  235. <span class="n">new_size</span> <span class="o">=</span> <span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="nb">round</span><span class="p">(</span><span class="n">images</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">*</span> <span class="n">ratio</span><span class="p">)),</span> <span class="nb">int</span><span class="p">(</span><span class="nb">round</span><span class="p">(</span><span class="n">images</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span> <span class="o">*</span> <span class="n">ratio</span><span class="p">)))</span>
  236. <span class="n">images</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">interpolate</span><span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="n">new_size</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s1">&#39;bilinear&#39;</span><span class="p">,</span> <span class="n">align_corners</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
  237. <span class="k">return</span> <span class="n">images</span><span class="p">,</span> <span class="n">batch</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></div>
  238. <div class="viewcode-block" id="AbstractPrePredictionCallback"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.datasets_utils.AbstractPrePredictionCallback">[docs]</a><span class="k">class</span> <span class="nc">AbstractPrePredictionCallback</span><span class="p">(</span><span class="n">ABC</span><span class="p">):</span>
  239. <span class="sd">&quot;&quot;&quot;</span>
  240. <span class="sd"> Abstract class for forward pass preprocessing function, to be used by passing its inheritors through training_params</span>
  241. <span class="sd"> pre_prediction_callback keyword arg.</span>
  242. <span class="sd"> Should implement __call__ and return images, targets after applying the desired preprocessing.</span>
  243. <span class="sd"> &quot;&quot;&quot;</span>
  244. <span class="nd">@abstractmethod</span>
  245. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">):</span>
  246. <span class="k">pass</span></div>
  247. <div class="viewcode-block" id="MultiscalePrePredictionCallback"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.datasets_utils.MultiscalePrePredictionCallback">[docs]</a><span class="k">class</span> <span class="nc">MultiscalePrePredictionCallback</span><span class="p">(</span><span class="n">AbstractPrePredictionCallback</span><span class="p">):</span>
  248. <span class="sd">&quot;&quot;&quot;</span>
  249. <span class="sd"> Mutiscale pre-prediction callback pass function.</span>
  250. <span class="sd"> When passed through train_params images, targets will be applied by the below transform to support multi scaling</span>
  251. <span class="sd"> on the fly.</span>
  252. <span class="sd"> After each self.frequency forward passes, change size randomly from</span>
  253. <span class="sd"> (input_size-self.multiscale_range*self.image_size_steps, input_size-(self.multiscale_range-1)*self.image_size_steps,</span>
  254. <span class="sd"> ...input_size+self.multiscale_range*self.image_size_steps)</span>
  255. <span class="sd"> Attributes:</span>
  256. <span class="sd"> multiscale_range: (int) Range of values for resize sizes as discussed above (default=5)</span>
  257. <span class="sd"> image_size_steps: (int) Image step sizes as discussed abov (default=32)</span>
  258. <span class="sd"> change_frequency: (int) The frequency to apply change in input size.</span>
  259. <span class="sd"> &quot;&quot;&quot;</span>
  260. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">multiscale_range</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">5</span><span class="p">,</span>
  261. <span class="n">image_size_steps</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span><span class="p">,</span>
  262. <span class="n">change_frequency</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">10</span><span class="p">):</span>
  263. <span class="bp">self</span><span class="o">.</span><span class="n">multiscale_range</span> <span class="o">=</span> <span class="n">multiscale_range</span>
  264. <span class="bp">self</span><span class="o">.</span><span class="n">image_size_steps</span> <span class="o">=</span> <span class="n">image_size_steps</span>
  265. <span class="bp">self</span><span class="o">.</span><span class="n">frequency</span> <span class="o">=</span> <span class="n">change_frequency</span>
  266. <span class="bp">self</span><span class="o">.</span><span class="n">rank</span> <span class="o">=</span> <span class="kc">None</span>
  267. <span class="bp">self</span><span class="o">.</span><span class="n">is_distributed</span> <span class="o">=</span> <span class="kc">None</span>
  268. <span class="bp">self</span><span class="o">.</span><span class="n">sampled_imres_once</span> <span class="o">=</span> <span class="kc">False</span>
  269. <span class="bp">self</span><span class="o">.</span><span class="n">new_input_size</span> <span class="o">=</span> <span class="kc">None</span>
  270. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">):</span>
  271. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">rank</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
  272. <span class="bp">self</span><span class="o">.</span><span class="n">rank</span> <span class="o">=</span> <span class="n">get_local_rank</span><span class="p">()</span>
  273. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_distributed</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
  274. <span class="bp">self</span><span class="o">.</span><span class="n">is_distributed</span> <span class="o">=</span> <span class="n">get_world_size</span><span class="p">()</span> <span class="o">&gt;</span> <span class="mi">1</span>
  275. <span class="c1"># GENERATE A NEW SIZE AND BROADCAST IT TO THE THE OTHER RANKS SO THEY HAVE THE SAME SCALE</span>
  276. <span class="n">input_size</span> <span class="o">=</span> <span class="n">inputs</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">:]</span>
  277. <span class="k">if</span> <span class="n">batch_idx</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">frequency</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
  278. <span class="n">tensor</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">LongTensor</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">inputs</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
  279. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">rank</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
  280. <span class="n">size_factor</span> <span class="o">=</span> <span class="n">input_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="n">input_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
  281. <span class="n">min_size</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">input_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">image_size_steps</span><span class="p">)</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">multiscale_range</span>
  282. <span class="n">max_size</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">input_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">image_size_steps</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">multiscale_range</span>
  283. <span class="n">random_size</span> <span class="o">=</span> <span class="p">(</span><span class="n">min_size</span><span class="p">,</span> <span class="n">max_size</span><span class="p">)</span>
  284. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">sampled_imres_once</span><span class="p">:</span>
  285. <span class="n">size</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="o">*</span><span class="n">random_size</span><span class="p">)</span>
  286. <span class="k">else</span><span class="p">:</span>
  287. <span class="c1"># sample the biggest resolution first to make sure the run fits into the GPU memory</span>
  288. <span class="n">size</span> <span class="o">=</span> <span class="n">max_size</span>
  289. <span class="bp">self</span><span class="o">.</span><span class="n">sampled_imres_once</span> <span class="o">=</span> <span class="kc">True</span>
  290. <span class="n">size</span> <span class="o">=</span> <span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">image_size_steps</span> <span class="o">*</span> <span class="n">size</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">image_size_steps</span> <span class="o">*</span> <span class="nb">int</span><span class="p">(</span><span class="n">size</span> <span class="o">*</span> <span class="n">size_factor</span><span class="p">))</span>
  291. <span class="n">tensor</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
  292. <span class="n">tensor</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">size</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
  293. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_distributed</span><span class="p">:</span>
  294. <span class="n">dist</span><span class="o">.</span><span class="n">barrier</span><span class="p">()</span>
  295. <span class="n">dist</span><span class="o">.</span><span class="n">broadcast</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
  296. <span class="bp">self</span><span class="o">.</span><span class="n">new_input_size</span> <span class="o">=</span> <span class="p">(</span><span class="n">tensor</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span> <span class="n">tensor</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
  297. <span class="n">scale_y</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">new_input_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">/</span> <span class="n">input_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
  298. <span class="n">scale_x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">new_input_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">/</span> <span class="n">input_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
  299. <span class="k">if</span> <span class="n">scale_x</span> <span class="o">!=</span> <span class="mi">1</span> <span class="ow">or</span> <span class="n">scale_y</span> <span class="o">!=</span> <span class="mi">1</span><span class="p">:</span>
  300. <span class="n">inputs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">functional</span><span class="o">.</span><span class="n">interpolate</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">new_input_size</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s2">&quot;bilinear&quot;</span><span class="p">,</span> <span class="n">align_corners</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
  301. <span class="k">return</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span></div>
  302. <div class="viewcode-block" id="DetectionMultiscalePrePredictionCallback"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.datasets_utils.DetectionMultiscalePrePredictionCallback">[docs]</a><span class="k">class</span> <span class="nc">DetectionMultiscalePrePredictionCallback</span><span class="p">(</span><span class="n">MultiscalePrePredictionCallback</span><span class="p">):</span>
  303. <span class="sd">&quot;&quot;&quot;</span>
  304. <span class="sd"> Mutiscalepre-prediction callback for object detection.</span>
  305. <span class="sd"> When passed through train_params images, targets will be applied by the below transform to support multi scaling</span>
  306. <span class="sd"> on the fly.</span>
  307. <span class="sd"> After each self.frequency forward passes, change size randomly from</span>
  308. <span class="sd"> (input_size-self.multiscale_range*self.image_size_steps, input_size-(self.multiscale_range-1)*self.image_size_steps,</span>
  309. <span class="sd"> ...input_size+self.multiscale_range*self.image_size_steps) and apply the same rescaling to the box coordinates.</span>
  310. <span class="sd"> Attributes:</span>
  311. <span class="sd"> multiscale_range: (int) Range of values for resize sizes as discussed above (default=5)</span>
  312. <span class="sd"> image_size_steps: (int) Image step sizes as discussed abov (default=32)</span>
  313. <span class="sd"> change_frequency: (int) The frequency to apply change in input size.</span>
  314. <span class="sd"> &quot;&quot;&quot;</span>
  315. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">):</span>
  316. <span class="c1"># RESCALE THE IMAGE FIRST WITH SUPER(), AND IF RESCALING HAS ACTUALLY BEEN DONE APPLY TO BOXES AS WELL</span>
  317. <span class="n">input_size</span> <span class="o">=</span> <span class="n">inputs</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">:]</span>
  318. <span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span> <span class="o">=</span> <span class="nb">super</span><span class="p">(</span><span class="n">DetectionMultiscalePrePredictionCallback</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__call__</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">)</span>
  319. <span class="n">new_input_size</span> <span class="o">=</span> <span class="n">inputs</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">:]</span>
  320. <span class="n">scale_y</span> <span class="o">=</span> <span class="n">new_input_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">/</span> <span class="n">input_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
  321. <span class="n">scale_x</span> <span class="o">=</span> <span class="n">new_input_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">/</span> <span class="n">input_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
  322. <span class="k">if</span> <span class="n">scale_x</span> <span class="o">!=</span> <span class="mi">1</span> <span class="ow">or</span> <span class="n">scale_y</span> <span class="o">!=</span> <span class="mi">1</span><span class="p">:</span>
  323. <span class="n">targets</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">2</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">targets</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">2</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">*</span> <span class="n">scale_x</span>
  324. <span class="n">targets</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">3</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">targets</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">3</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">*</span> <span class="n">scale_y</span>
  325. <span class="k">return</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span></div>
  326. <span class="n">_pil_interpolation_to_str</span> <span class="o">=</span> <span class="p">{</span>
  327. <span class="n">Image</span><span class="o">.</span><span class="n">NEAREST</span><span class="p">:</span> <span class="s1">&#39;PIL.Image.NEAREST&#39;</span><span class="p">,</span>
  328. <span class="n">Image</span><span class="o">.</span><span class="n">BILINEAR</span><span class="p">:</span> <span class="s1">&#39;PIL.Image.BILINEAR&#39;</span><span class="p">,</span>
  329. <span class="n">Image</span><span class="o">.</span><span class="n">BICUBIC</span><span class="p">:</span> <span class="s1">&#39;PIL.Image.BICUBIC&#39;</span><span class="p">,</span>
  330. <span class="n">Image</span><span class="o">.</span><span class="n">LANCZOS</span><span class="p">:</span> <span class="s1">&#39;PIL.Image.LANCZOS&#39;</span><span class="p">,</span>
  331. <span class="n">Image</span><span class="o">.</span><span class="n">HAMMING</span><span class="p">:</span> <span class="s1">&#39;PIL.Image.HAMMING&#39;</span><span class="p">,</span>
  332. <span class="n">Image</span><span class="o">.</span><span class="n">BOX</span><span class="p">:</span> <span class="s1">&#39;PIL.Image.BOX&#39;</span><span class="p">,</span>
  333. <span class="p">}</span>
  334. <span class="k">def</span> <span class="nf">_pil_interp</span><span class="p">(</span><span class="n">method</span><span class="p">):</span>
  335. <span class="k">if</span> <span class="n">method</span> <span class="o">==</span> <span class="s1">&#39;bicubic&#39;</span><span class="p">:</span>
  336. <span class="k">return</span> <span class="n">InterpolationMode</span><span class="o">.</span><span class="n">BICUBIC</span>
  337. <span class="k">elif</span> <span class="n">method</span> <span class="o">==</span> <span class="s1">&#39;lanczos&#39;</span><span class="p">:</span>
  338. <span class="k">return</span> <span class="n">InterpolationMode</span><span class="o">.</span><span class="n">LANCZOS</span>
  339. <span class="k">elif</span> <span class="n">method</span> <span class="o">==</span> <span class="s1">&#39;hamming&#39;</span><span class="p">:</span>
  340. <span class="k">return</span> <span class="n">InterpolationMode</span><span class="o">.</span><span class="n">HAMMING</span>
  341. <span class="k">elif</span> <span class="n">method</span> <span class="o">==</span> <span class="s1">&#39;nearest&#39;</span><span class="p">:</span>
  342. <span class="k">return</span> <span class="n">InterpolationMode</span><span class="o">.</span><span class="n">NEAREST</span>
  343. <span class="k">elif</span> <span class="n">method</span> <span class="o">==</span> <span class="s1">&#39;bilinear&#39;</span><span class="p">:</span>
  344. <span class="k">return</span> <span class="n">InterpolationMode</span><span class="o">.</span><span class="n">BILINEAR</span>
  345. <span class="k">elif</span> <span class="n">method</span> <span class="o">==</span> <span class="s1">&#39;box&#39;</span><span class="p">:</span>
  346. <span class="k">return</span> <span class="n">InterpolationMode</span><span class="o">.</span><span class="n">BOX</span>
  347. <span class="k">else</span><span class="p">:</span>
  348. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;interpolation type must be one of [&#39;bilinear&#39;, &#39;bicubic&#39;, &#39;lanczos&#39;, &#39;hamming&#39;, &quot;</span>
  349. <span class="s2">&quot;&#39;nearest&#39;, &#39;box&#39;] for explicit interpolation type, or &#39;random&#39; for random&quot;</span><span class="p">)</span>
  350. <span class="n">_RANDOM_INTERPOLATION</span> <span class="o">=</span> <span class="p">(</span><span class="n">InterpolationMode</span><span class="o">.</span><span class="n">BILINEAR</span><span class="p">,</span> <span class="n">InterpolationMode</span><span class="o">.</span><span class="n">BICUBIC</span><span class="p">)</span>
  351. <div class="viewcode-block" id="RandomResizedCropAndInterpolation"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.datasets_utils.RandomResizedCropAndInterpolation">[docs]</a><span class="k">class</span> <span class="nc">RandomResizedCropAndInterpolation</span><span class="p">(</span><span class="n">RandomResizedCrop</span><span class="p">):</span>
  352. <span class="sd">&quot;&quot;&quot;</span>
  353. <span class="sd"> Crop the given PIL Image to random size and aspect ratio with explicitly chosen or random interpolation.</span>
  354. <span class="sd"> A crop of random size (default: of 0.08 to 1.0) of the original size and a random</span>
  355. <span class="sd"> aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop</span>
  356. <span class="sd"> is finally resized to given size.</span>
  357. <span class="sd"> This is popularly used to train the Inception networks.</span>
  358. <span class="sd"> Args:</span>
  359. <span class="sd"> size: expected output size of each edge</span>
  360. <span class="sd"> scale: range of size of the origin size cropped</span>
  361. <span class="sd"> ratio: range of aspect ratio of the origin aspect ratio cropped</span>
  362. <span class="sd"> interpolation: Default: PIL.Image.BILINEAR</span>
  363. <span class="sd"> &quot;&quot;&quot;</span>
  364. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">size</span><span class="p">,</span> <span class="n">scale</span><span class="o">=</span><span class="p">(</span><span class="mf">0.08</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">),</span> <span class="n">ratio</span><span class="o">=</span><span class="p">(</span><span class="mf">3.</span> <span class="o">/</span> <span class="mf">4.</span><span class="p">,</span> <span class="mf">4.</span> <span class="o">/</span> <span class="mf">3.</span><span class="p">),</span>
  365. <span class="n">interpolation</span><span class="o">=</span><span class="s1">&#39;default&#39;</span><span class="p">):</span>
  366. <span class="nb">super</span><span class="p">(</span><span class="n">RandomResizedCropAndInterpolation</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="n">size</span><span class="p">,</span> <span class="n">scale</span><span class="o">=</span><span class="n">scale</span><span class="p">,</span> <span class="n">ratio</span><span class="o">=</span><span class="n">ratio</span><span class="p">,</span> <span class="n">interpolation</span><span class="o">=</span><span class="n">interpolation</span><span class="p">)</span>
  367. <span class="k">if</span> <span class="n">interpolation</span> <span class="o">==</span> <span class="s1">&#39;random&#39;</span><span class="p">:</span>
  368. <span class="bp">self</span><span class="o">.</span><span class="n">interpolation</span> <span class="o">=</span> <span class="n">_RANDOM_INTERPOLATION</span>
  369. <span class="k">elif</span> <span class="n">interpolation</span> <span class="o">==</span> <span class="s1">&#39;default&#39;</span><span class="p">:</span>
  370. <span class="bp">self</span><span class="o">.</span><span class="n">interpolation</span> <span class="o">=</span> <span class="n">InterpolationMode</span><span class="o">.</span><span class="n">BILINEAR</span>
  371. <span class="k">else</span><span class="p">:</span>
  372. <span class="bp">self</span><span class="o">.</span><span class="n">interpolation</span> <span class="o">=</span> <span class="n">_pil_interp</span><span class="p">(</span><span class="n">interpolation</span><span class="p">)</span>
  373. <div class="viewcode-block" id="RandomResizedCropAndInterpolation.forward"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.datasets_utils.RandomResizedCropAndInterpolation.forward">[docs]</a> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">img</span><span class="p">):</span>
  374. <span class="sd">&quot;&quot;&quot;</span>
  375. <span class="sd"> Args:</span>
  376. <span class="sd"> img (PIL Image): Image to be cropped and resized.</span>
  377. <span class="sd"> Returns:</span>
  378. <span class="sd"> PIL Image: Randomly cropped and resized image.</span>
  379. <span class="sd"> &quot;&quot;&quot;</span>
  380. <span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">w</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_params</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">ratio</span><span class="p">)</span>
  381. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">interpolation</span><span class="p">,</span> <span class="p">(</span><span class="nb">tuple</span><span class="p">,</span> <span class="nb">list</span><span class="p">)):</span>
  382. <span class="n">interpolation</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">interpolation</span><span class="p">)</span>
  383. <span class="k">else</span><span class="p">:</span>
  384. <span class="n">interpolation</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">interpolation</span>
  385. <span class="k">return</span> <span class="n">torchvision</span><span class="o">.</span><span class="n">transforms</span><span class="o">.</span><span class="n">functional</span><span class="o">.</span><span class="n">resized_crop</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">size</span><span class="p">,</span> <span class="n">interpolation</span><span class="p">)</span></div>
  386. <span class="k">def</span> <span class="fm">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  387. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">interpolation</span><span class="p">,</span> <span class="p">(</span><span class="nb">tuple</span><span class="p">,</span> <span class="nb">list</span><span class="p">)):</span>
  388. <span class="n">interpolate_str</span> <span class="o">=</span> <span class="s1">&#39; &#39;</span><span class="o">.</span><span class="n">join</span><span class="p">([</span><span class="n">_pil_interpolation_to_str</span><span class="p">[</span><span class="n">x</span><span class="p">]</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">interpolation</span><span class="p">])</span>
  389. <span class="k">else</span><span class="p">:</span>
  390. <span class="n">interpolate_str</span> <span class="o">=</span> <span class="n">_pil_interpolation_to_str</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">interpolation</span><span class="p">]</span>
  391. <span class="n">format_string</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span> <span class="o">+</span> <span class="s1">&#39;(size=</span><span class="si">{0}</span><span class="s1">&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">size</span><span class="p">)</span>
  392. <span class="n">format_string</span> <span class="o">+=</span> <span class="s1">&#39;, scale=</span><span class="si">{0}</span><span class="s1">&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="nb">tuple</span><span class="p">(</span><span class="nb">round</span><span class="p">(</span><span class="n">s</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span> <span class="k">for</span> <span class="n">s</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span><span class="p">))</span>
  393. <span class="n">format_string</span> <span class="o">+=</span> <span class="s1">&#39;, ratio=</span><span class="si">{0}</span><span class="s1">&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="nb">tuple</span><span class="p">(</span><span class="nb">round</span><span class="p">(</span><span class="n">r</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span> <span class="k">for</span> <span class="n">r</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">ratio</span><span class="p">))</span>
  394. <span class="n">format_string</span> <span class="o">+=</span> <span class="s1">&#39;, interpolation=</span><span class="si">{0}</span><span class="s1">)&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">interpolate_str</span><span class="p">)</span>
  395. <span class="k">return</span> <span class="n">format_string</span></div>
  396. <span class="n">STAT_LOGGER_FONT_SIZE</span> <span class="o">=</span> <span class="mi">15</span>
  397. <div class="viewcode-block" id="DatasetStatisticsTensorboardLogger"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.datasets_utils.DatasetStatisticsTensorboardLogger">[docs]</a><span class="k">class</span> <span class="nc">DatasetStatisticsTensorboardLogger</span><span class="p">:</span>
  398. <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
  399. <span class="n">DEFAULT_SUMMARY_PARAMS</span> <span class="o">=</span> <span class="p">{</span>
  400. <span class="s1">&#39;sample_images&#39;</span><span class="p">:</span> <span class="mi">32</span><span class="p">,</span> <span class="c1"># by default, 32 images will be sampled from each dataset</span>
  401. <span class="s1">&#39;plot_class_distribution&#39;</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span>
  402. <span class="s1">&#39;plot_box_size_distribution&#39;</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span>
  403. <span class="s1">&#39;plot_anchors_coverage&#39;</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span>
  404. <span class="s1">&#39;max_batches&#39;</span><span class="p">:</span> <span class="mi">30</span>
  405. <span class="p">}</span>
  406. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sg_logger</span><span class="p">:</span> <span class="n">AbstractSGLogger</span><span class="p">,</span> <span class="n">summary_params</span><span class="p">:</span> <span class="nb">dict</span> <span class="o">=</span> <span class="n">DEFAULT_SUMMARY_PARAMS</span><span class="p">):</span>
  407. <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span> <span class="o">=</span> <span class="n">sg_logger</span>
  408. <span class="bp">self</span><span class="o">.</span><span class="n">summary_params</span> <span class="o">=</span> <span class="p">{</span><span class="o">**</span><span class="n">DatasetStatisticsTensorboardLogger</span><span class="o">.</span><span class="n">DEFAULT_SUMMARY_PARAMS</span><span class="p">,</span> <span class="o">**</span><span class="n">summary_params</span><span class="p">}</span>
  409. <div class="viewcode-block" id="DatasetStatisticsTensorboardLogger.analyze"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.datasets_utils.DatasetStatisticsTensorboardLogger.analyze">[docs]</a> <span class="k">def</span> <span class="nf">analyze</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data_loader</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">,</span> <span class="n">title</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
  410. <span class="n">all_classes</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> <span class="n">anchors</span><span class="p">:</span> <span class="nb">list</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  411. <span class="sd">&quot;&quot;&quot;</span>
  412. <span class="sd"> :param data_loader: the dataset data loader</span>
  413. <span class="sd"> :param dataset_params: the dataset parameters</span>
  414. <span class="sd"> :param title: the title for this dataset (i.e. Coco 2017 test set)</span>
  415. <span class="sd"> :param anchors: the list of anchors used by the model. applicable only for detection datasets</span>
  416. <span class="sd"> :param all_classes: the list of all classes names</span>
  417. <span class="sd"> &quot;&quot;&quot;</span>
  418. <span class="c1"># FIXME: UNCOMMENT AND APPLY TO NEW DetectionDataSet ONCE ITS MERGED</span>
  419. <span class="c1"># if isinstance(data_loader.dataset, DetectionDataSet):</span>
  420. <span class="c1"># self._analyze_detection(data_loader=data_loader, title=title,</span>
  421. <span class="c1"># all_classes=all_classes, anchors=anchors)</span>
  422. <span class="c1"># else:</span>
  423. <span class="c1"># DatasetStatisticsTensorboardLogger.logger.warning(&#39;only DetectionDataSet are currently supported&#39;)</span>
  424. <span class="n">DatasetStatisticsTensorboardLogger</span><span class="o">.</span><span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s1">&#39;only DetectionDataSet are currently supported&#39;</span><span class="p">)</span></div>
  425. <span class="k">def</span> <span class="nf">_analyze_detection</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data_loader</span><span class="p">,</span> <span class="n">title</span><span class="p">,</span> <span class="n">all_classes</span><span class="p">,</span> <span class="n">anchors</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  426. <span class="sd">&quot;&quot;&quot;</span>
  427. <span class="sd"> Analyze a detection dataset</span>
  428. <span class="sd"> :param data_loader: the dataset data loader</span>
  429. <span class="sd"> :param dataset_params: the dataset parameters</span>
  430. <span class="sd"> :param all_classes: the list of all classes names</span>
  431. <span class="sd"> :param title: the title for this dataset (i.e. Coco 2017 test set)</span>
  432. <span class="sd"> :param anchors: the list of anchors used by the model. if not provided, anchors coverage will not be analyzed</span>
  433. <span class="sd"> &quot;&quot;&quot;</span>
  434. <span class="k">try</span><span class="p">:</span>
  435. <span class="n">color_mean</span> <span class="o">=</span> <span class="n">AverageMeter</span><span class="p">()</span>
  436. <span class="n">color_std</span> <span class="o">=</span> <span class="n">AverageMeter</span><span class="p">()</span>
  437. <span class="n">all_labels</span> <span class="o">=</span> <span class="p">[]</span>
  438. <span class="n">image_size</span> <span class="o">=</span> <span class="mi">0</span>
  439. <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">labels</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">tqdm</span><span class="p">(</span><span class="n">data_loader</span><span class="p">)):</span>
  440. <span class="k">if</span> <span class="n">i</span> <span class="o">&gt;=</span> <span class="bp">self</span><span class="o">.</span><span class="n">summary_params</span><span class="p">[</span><span class="s1">&#39;max_batches&#39;</span><span class="p">]</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
  441. <span class="k">break</span>
  442. <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
  443. <span class="n">image_size</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">images</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">images</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">])</span>
  444. <span class="k">if</span> <span class="n">images</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">summary_params</span><span class="p">[</span><span class="s1">&#39;sample_images&#39;</span><span class="p">]:</span>
  445. <span class="n">samples</span> <span class="o">=</span> <span class="n">images</span><span class="p">[:</span><span class="bp">self</span><span class="o">.</span><span class="n">summary_params</span><span class="p">[</span><span class="s1">&#39;sample_images&#39;</span><span class="p">]]</span>
  446. <span class="k">else</span><span class="p">:</span>
  447. <span class="n">samples</span> <span class="o">=</span> <span class="n">images</span>
  448. <span class="n">pred</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">6</span><span class="p">))</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">samples</span><span class="p">))]</span>
  449. <span class="k">try</span><span class="p">:</span>
  450. <span class="n">result_images</span> <span class="o">=</span> <span class="n">DetectionVisualization</span><span class="o">.</span><span class="n">visualize_batch</span><span class="p">(</span><span class="n">image_tensor</span><span class="o">=</span><span class="n">samples</span><span class="p">,</span> <span class="n">pred_boxes</span><span class="o">=</span><span class="n">pred</span><span class="p">,</span>
  451. <span class="n">target_boxes</span><span class="o">=</span><span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="n">labels</span><span class="p">),</span>
  452. <span class="n">batch_name</span><span class="o">=</span><span class="n">title</span><span class="p">,</span>
  453. <span class="n">class_names</span><span class="o">=</span><span class="n">all_classes</span><span class="p">,</span>
  454. <span class="n">box_thickness</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
  455. <span class="n">gt_alpha</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)</span>
  456. <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">add_images</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">title</span><span class="si">}</span><span class="s1"> sample images&#39;</span><span class="p">,</span> <span class="n">images</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">result_images</span><span class="p">)</span>
  457. <span class="o">.</span><span class="n">transpose</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">])[:,</span> <span class="p">::</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:])</span>
  458. <span class="k">except</span> <span class="ne">Exception</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
  459. <span class="n">DatasetStatisticsTensorboardLogger</span><span class="o">.</span><span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span>
  460. <span class="sa">f</span><span class="s1">&#39;Dataset Statistics failed at adding an example batch:</span><span class="se">\n</span><span class="si">{</span><span class="n">e</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
  461. <span class="k">return</span>
  462. <span class="n">all_labels</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">labels</span><span class="p">)</span>
  463. <span class="n">color_mean</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">]),</span> <span class="mi">1</span><span class="p">)</span>
  464. <span class="n">color_std</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">std</span><span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">]),</span> <span class="mi">1</span><span class="p">)</span>
  465. <span class="n">all_labels</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">all_labels</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)[</span><span class="mi">1</span><span class="p">:]</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
  466. <span class="k">try</span><span class="p">:</span>
  467. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">summary_params</span><span class="p">[</span><span class="s1">&#39;plot_class_distribution&#39;</span><span class="p">]:</span>
  468. <span class="bp">self</span><span class="o">.</span><span class="n">_analyze_class_distribution</span><span class="p">(</span><span class="n">labels</span><span class="o">=</span><span class="n">all_labels</span><span class="p">,</span> <span class="n">num_classes</span><span class="o">=</span><span class="nb">len</span><span class="p">(</span><span class="n">all_classes</span><span class="p">),</span> <span class="n">title</span><span class="o">=</span><span class="n">title</span><span class="p">)</span>
  469. <span class="k">except</span> <span class="ne">Exception</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
  470. <span class="n">DatasetStatisticsTensorboardLogger</span><span class="o">.</span><span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Dataset Statistics failed at analyzing class distributions.</span><span class="se">\n</span><span class="si">{</span><span class="n">e</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
  471. <span class="k">return</span>
  472. <span class="k">try</span><span class="p">:</span>
  473. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">summary_params</span><span class="p">[</span><span class="s1">&#39;plot_box_size_distribution&#39;</span><span class="p">]:</span>
  474. <span class="bp">self</span><span class="o">.</span><span class="n">_analyze_object_size_distribution</span><span class="p">(</span><span class="n">labels</span><span class="o">=</span><span class="n">all_labels</span><span class="p">,</span> <span class="n">title</span><span class="o">=</span><span class="n">title</span><span class="p">)</span>
  475. <span class="k">except</span> <span class="ne">Exception</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
  476. <span class="n">DatasetStatisticsTensorboardLogger</span><span class="o">.</span><span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Dataset Statistics failed at analyzing object size &#39;</span>
  477. <span class="sa">f</span><span class="s1">&#39;distributions.</span><span class="se">\n</span><span class="si">{</span><span class="n">e</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
  478. <span class="k">return</span>
  479. <span class="n">summary</span> <span class="o">=</span> <span class="s1">&#39;&#39;</span>
  480. <span class="n">summary</span> <span class="o">+=</span> <span class="sa">f</span><span class="s1">&#39;dataset size: </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">data_loader</span><span class="p">)</span><span class="si">}</span><span class="s1"> </span><span class="se">\n</span><span class="s1">&#39;</span>
  481. <span class="n">summary</span> <span class="o">+=</span> <span class="sa">f</span><span class="s1">&#39;color mean: </span><span class="si">{</span><span class="n">color_mean</span><span class="o">.</span><span class="n">average</span><span class="si">}</span><span class="s1"> </span><span class="se">\n</span><span class="s1">&#39;</span>
  482. <span class="n">summary</span> <span class="o">+=</span> <span class="sa">f</span><span class="s1">&#39;color std: </span><span class="si">{</span><span class="n">color_std</span><span class="o">.</span><span class="n">average</span><span class="si">}</span><span class="s1"> </span><span class="se">\n</span><span class="s1">&#39;</span>
  483. <span class="k">try</span><span class="p">:</span>
  484. <span class="k">if</span> <span class="n">anchors</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">image_size</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
  485. <span class="n">coverage</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_analyze_anchors_coverage</span><span class="p">(</span><span class="n">anchors</span><span class="o">=</span><span class="n">anchors</span><span class="p">,</span> <span class="n">image_size</span><span class="o">=</span><span class="n">image_size</span><span class="p">,</span>
  486. <span class="n">title</span><span class="o">=</span><span class="n">title</span><span class="p">,</span> <span class="n">labels</span><span class="o">=</span><span class="n">all_labels</span><span class="p">)</span>
  487. <span class="n">summary</span> <span class="o">+=</span> <span class="sa">f</span><span class="s1">&#39;anchors: </span><span class="si">{</span><span class="n">anchors</span><span class="si">}</span><span class="s1"> </span><span class="se">\n</span><span class="s1">&#39;</span>
  488. <span class="n">summary</span> <span class="o">+=</span> <span class="sa">f</span><span class="s1">&#39;anchors coverage: </span><span class="si">{</span><span class="n">coverage</span><span class="si">}</span><span class="s1"> </span><span class="se">\n</span><span class="s1">&#39;</span>
  489. <span class="k">except</span> <span class="ne">Exception</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
  490. <span class="n">DatasetStatisticsTensorboardLogger</span><span class="o">.</span><span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Dataset Statistics failed at analyzing anchors &#39;</span>
  491. <span class="sa">f</span><span class="s1">&#39;coverage.</span><span class="se">\n</span><span class="si">{</span><span class="n">e</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
  492. <span class="k">return</span>
  493. <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">add_text</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">title</span><span class="si">}</span><span class="s1"> Statistics&#39;</span><span class="p">,</span> <span class="n">text_string</span><span class="o">=</span><span class="n">summary</span><span class="p">)</span>
  494. <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">flush</span><span class="p">()</span>
  495. <span class="k">except</span> <span class="ne">Exception</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
  496. <span class="n">DatasetStatisticsTensorboardLogger</span><span class="o">.</span><span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;dataset analysis failed!</span><span class="se">\n</span><span class="si">{</span><span class="n">e</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
  497. <span class="k">def</span> <span class="nf">_analyze_class_distribution</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">labels</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">title</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
  498. <span class="n">hist</span><span class="p">,</span> <span class="n">edges</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">histogram</span><span class="p">(</span><span class="n">labels</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">num_classes</span><span class="p">)</span>
  499. <span class="n">f</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">[</span><span class="mi">10</span><span class="p">,</span> <span class="mi">8</span><span class="p">])</span>
  500. <span class="n">plt</span><span class="o">.</span><span class="n">bar</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">num_classes</span><span class="p">),</span> <span class="n">hist</span><span class="p">,</span> <span class="n">width</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s1">&#39;#0504aa&#39;</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.7</span><span class="p">)</span>
  501. <span class="n">plt</span><span class="o">.</span><span class="n">xlim</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">)</span>
  502. <span class="n">plt</span><span class="o">.</span><span class="n">grid</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="s1">&#39;y&#39;</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.75</span><span class="p">)</span>
  503. <span class="n">plt</span><span class="o">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s1">&#39;Value&#39;</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="n">STAT_LOGGER_FONT_SIZE</span><span class="p">)</span>
  504. <span class="n">plt</span><span class="o">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s1">&#39;Frequency&#39;</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="n">STAT_LOGGER_FONT_SIZE</span><span class="p">)</span>
  505. <span class="n">plt</span><span class="o">.</span><span class="n">xticks</span><span class="p">(</span><span class="n">fontsize</span><span class="o">=</span><span class="n">STAT_LOGGER_FONT_SIZE</span><span class="p">)</span>
  506. <span class="n">plt</span><span class="o">.</span><span class="n">yticks</span><span class="p">(</span><span class="n">fontsize</span><span class="o">=</span><span class="n">STAT_LOGGER_FONT_SIZE</span><span class="p">)</span>
  507. <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">title</span><span class="si">}</span><span class="s1"> class distribution&#39;</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="n">STAT_LOGGER_FONT_SIZE</span><span class="p">)</span>
  508. <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">add_figure</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">title</span><span class="si">}</span><span class="s2"> class distribution&quot;</span><span class="p">,</span> <span class="n">figure</span><span class="o">=</span><span class="n">f</span><span class="p">)</span>
  509. <span class="n">text_dist</span> <span class="o">=</span> <span class="s1">&#39;&#39;</span>
  510. <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">val</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">hist</span><span class="p">):</span>
  511. <span class="n">text_dist</span> <span class="o">+=</span> <span class="sa">f</span><span class="s1">&#39;[</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s1">]: </span><span class="si">{</span><span class="n">val</span><span class="si">}</span><span class="s1">, &#39;</span>
  512. <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">add_text</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">title</span><span class="si">}</span><span class="s2"> class distribution&quot;</span><span class="p">,</span> <span class="n">text_string</span><span class="o">=</span><span class="n">text_dist</span><span class="p">)</span>
  513. <span class="k">def</span> <span class="nf">_analyze_object_size_distribution</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">labels</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span> <span class="n">title</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
  514. <span class="sd">&quot;&quot;&quot;</span>
  515. <span class="sd"> This function will add two plots to the tensorboard.</span>
  516. <span class="sd"> one is a 2D histogram and the other is a scatter plot. in both cases the X axis is the object width and Y axis</span>
  517. <span class="sd"> is the object width (both normalized by image size)</span>
  518. <span class="sd"> :param labels: all the labels of the dataset of the shape [class_label, x_center, y_center, w, h]</span>
  519. <span class="sd"> :param title: the dataset title</span>
  520. <span class="sd"> &quot;&quot;&quot;</span>
  521. <span class="c1"># histogram plot</span>
  522. <span class="n">hist</span><span class="p">,</span> <span class="n">xedges</span><span class="p">,</span> <span class="n">yedges</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">histogram2d</span><span class="p">(</span><span class="n">labels</span><span class="p">[:,</span> <span class="mi">4</span><span class="p">],</span> <span class="n">labels</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">],</span> <span class="mi">50</span><span class="p">)</span> <span class="c1"># x and y are deliberately switched</span>
  523. <span class="n">fig</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">6</span><span class="p">))</span>
  524. <span class="n">fig</span><span class="o">.</span><span class="n">suptitle</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">title</span><span class="si">}</span><span class="s1"> boxes w/h distribution&#39;</span><span class="p">)</span>
  525. <span class="n">ax</span> <span class="o">=</span> <span class="n">fig</span><span class="o">.</span><span class="n">add_subplot</span><span class="p">(</span><span class="mi">121</span><span class="p">)</span>
  526. <span class="n">ax</span><span class="o">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s1">&#39;W&#39;</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="n">STAT_LOGGER_FONT_SIZE</span><span class="p">)</span>
  527. <span class="n">ax</span><span class="o">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s1">&#39;H&#39;</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="n">STAT_LOGGER_FONT_SIZE</span><span class="p">)</span>
  528. <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">hist</span> <span class="o">+</span> <span class="mi">1</span><span class="p">),</span> <span class="n">interpolation</span><span class="o">=</span><span class="s1">&#39;nearest&#39;</span><span class="p">,</span> <span class="n">origin</span><span class="o">=</span><span class="s1">&#39;lower&#39;</span><span class="p">,</span>
  529. <span class="n">extent</span><span class="o">=</span><span class="p">[</span><span class="n">xedges</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">xedges</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">yedges</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">yedges</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]])</span>
  530. <span class="c1"># scatter plot</span>
  531. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">labels</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">10000</span><span class="p">:</span>
  532. <span class="c1"># we randomly sample just 10000 objects so that the scatter plot will not get too dense</span>
  533. <span class="n">labels</span> <span class="o">=</span> <span class="n">labels</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">labels</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">10000</span><span class="p">)]</span>
  534. <span class="n">ax</span> <span class="o">=</span> <span class="n">fig</span><span class="o">.</span><span class="n">add_subplot</span><span class="p">(</span><span class="mi">122</span><span class="p">)</span>
  535. <span class="n">ax</span><span class="o">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s1">&#39;W&#39;</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="n">STAT_LOGGER_FONT_SIZE</span><span class="p">)</span>
  536. <span class="n">ax</span><span class="o">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s1">&#39;H&#39;</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="n">STAT_LOGGER_FONT_SIZE</span><span class="p">)</span>
  537. <span class="n">plt</span><span class="o">.</span><span class="n">scatter</span><span class="p">(</span><span class="n">labels</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">],</span> <span class="n">labels</span><span class="p">[:,</span> <span class="mi">4</span><span class="p">],</span> <span class="n">marker</span><span class="o">=</span><span class="s1">&#39;.&#39;</span><span class="p">)</span>
  538. <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">add_figure</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">title</span><span class="si">}</span><span class="s1"> boxes w/h distribution&#39;</span><span class="p">,</span> <span class="n">figure</span><span class="o">=</span><span class="n">fig</span><span class="p">)</span>
  539. <span class="nd">@staticmethod</span>
  540. <span class="k">def</span> <span class="nf">_get_rect</span><span class="p">(</span><span class="n">w</span><span class="p">,</span> <span class="n">h</span><span class="p">):</span>
  541. <span class="n">min_w</span> <span class="o">=</span> <span class="n">w</span> <span class="o">/</span> <span class="mf">4.0</span>
  542. <span class="n">min_h</span> <span class="o">=</span> <span class="n">h</span> <span class="o">/</span> <span class="mf">4.0</span>
  543. <span class="k">return</span> <span class="n">Rectangle</span><span class="p">((</span><span class="n">min_w</span><span class="p">,</span> <span class="n">min_h</span><span class="p">),</span> <span class="n">w</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">-</span> <span class="n">min_w</span><span class="p">,</span> <span class="n">h</span> <span class="o">*</span> <span class="mi">4</span> <span class="o">-</span> <span class="n">min_h</span><span class="p">,</span> <span class="n">linewidth</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">edgecolor</span><span class="o">=</span><span class="s1">&#39;b&#39;</span><span class="p">,</span> <span class="n">facecolor</span><span class="o">=</span><span class="s1">&#39;none&#39;</span><span class="p">)</span>
  544. <span class="nd">@staticmethod</span>
  545. <span class="k">def</span> <span class="nf">_get_score</span><span class="p">(</span><span class="n">anchors</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">points</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">image_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
  546. <span class="sd">&quot;&quot;&quot;</span>
  547. <span class="sd"> Calculate the ratio (and 1/ratio) between each anchor width and height and each point (representing a possible</span>
  548. <span class="sd"> object width and height).</span>
  549. <span class="sd"> i.e. for an anchor with w=10,h=20 the point w=11,h=25 will have the ratios 11/10=1.1 and 25/20=1.25</span>
  550. <span class="sd"> or 10/11=0.91 and 20/25=0.8 respectively</span>
  551. <span class="sd"> :param anchors: array of anchors of the shape [2,N]</span>
  552. <span class="sd"> :param points: array of points of the shape [2,M]</span>
  553. <span class="sd"> :param image_size the size of the input image</span>
  554. <span class="sd"> :returns: an array of size [image_size - 1, image_size - 1] where each cell i,j represent the minimum ratio</span>
  555. <span class="sd"> for that cell (point) from all anchors</span>
  556. <span class="sd"> &quot;&quot;&quot;</span>
  557. <span class="n">ratio</span> <span class="o">=</span> <span class="n">anchors</span><span class="p">[:,</span> <span class="p">:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">/</span> <span class="n">points</span><span class="p">[:,</span> <span class="p">]</span>
  558. <span class="n">inv_ratio</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">ratio</span>
  559. <span class="n">min_ratio</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">np</span><span class="o">.</span><span class="n">minimum</span><span class="p">(</span><span class="n">ratio</span><span class="p">,</span> <span class="n">inv_ratio</span><span class="p">)</span>
  560. <span class="n">min_ratio</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">min_ratio</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
  561. <span class="n">to_closest_anchor</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">min_ratio</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
  562. <span class="n">to_closest_anchor</span><span class="p">[</span><span class="n">to_closest_anchor</span> <span class="o">&gt;</span> <span class="mf">0.75</span><span class="p">]</span> <span class="o">=</span> <span class="mi">2</span>
  563. <span class="k">return</span> <span class="n">to_closest_anchor</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">image_size</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
  564. <span class="k">def</span> <span class="nf">_analyze_anchors_coverage</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">anchors</span><span class="p">:</span> <span class="n">Anchors</span><span class="p">,</span> <span class="n">image_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">labels</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span> <span class="n">title</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
  565. <span class="sd">&quot;&quot;&quot;</span>
  566. <span class="sd"> This function will add anchors coverage plots to the tensorboard.</span>
  567. <span class="sd"> :param anchors: a list of anchors</span>
  568. <span class="sd"> :param image_size: the input image size for this training</span>
  569. <span class="sd"> :param labels: all the labels of the dataset of the shape [class_label, x_center, y_center, w, h]</span>
  570. <span class="sd"> :param title: the dataset title</span>
  571. <span class="sd"> &quot;&quot;&quot;</span>
  572. <span class="n">fig</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">12</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span>
  573. <span class="n">fig</span><span class="o">.</span><span class="n">suptitle</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">title</span><span class="si">}</span><span class="s1"> anchors coverage&#39;</span><span class="p">)</span>
  574. <span class="c1"># box style plot</span>
  575. <span class="n">ax</span> <span class="o">=</span> <span class="n">fig</span><span class="o">.</span><span class="n">add_subplot</span><span class="p">(</span><span class="mi">121</span><span class="p">)</span>
  576. <span class="n">ax</span><span class="o">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s1">&#39;W&#39;</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="n">STAT_LOGGER_FONT_SIZE</span><span class="p">)</span>
  577. <span class="n">ax</span><span class="o">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s1">&#39;H&#39;</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="n">STAT_LOGGER_FONT_SIZE</span><span class="p">)</span>
  578. <span class="n">ax</span><span class="o">.</span><span class="n">set_xlim</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="n">image_size</span><span class="p">])</span>
  579. <span class="n">ax</span><span class="o">.</span><span class="n">set_ylim</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="n">image_size</span><span class="p">])</span>
  580. <span class="n">anchors_boxes</span> <span class="o">=</span> <span class="n">anchors</span><span class="o">.</span><span class="n">anchors</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
  581. <span class="n">anchors_len</span> <span class="o">=</span> <span class="n">anchors</span><span class="o">.</span><span class="n">num_anchors</span>
  582. <span class="n">anchors_boxes</span> <span class="o">=</span> <span class="n">anchors_boxes</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
  583. <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">anchors_len</span><span class="p">):</span>
  584. <span class="n">rect</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_rect</span><span class="p">(</span><span class="n">anchors_boxes</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">0</span><span class="p">],</span> <span class="n">anchors_boxes</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">1</span><span class="p">])</span>
  585. <span class="n">rect</span><span class="o">.</span><span class="n">set_alpha</span><span class="p">(</span><span class="mf">0.3</span><span class="p">)</span>
  586. <span class="n">rect</span><span class="o">.</span><span class="n">set_facecolor</span><span class="p">([</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">(),</span> <span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">(),</span> <span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">(),</span> <span class="mf">0.3</span><span class="p">])</span>
  587. <span class="n">ax</span><span class="o">.</span><span class="n">add_patch</span><span class="p">(</span><span class="n">rect</span><span class="p">)</span>
  588. <span class="c1"># distance from anchor plot</span>
  589. <span class="n">ax</span> <span class="o">=</span> <span class="n">fig</span><span class="o">.</span><span class="n">add_subplot</span><span class="p">(</span><span class="mi">122</span><span class="p">)</span>
  590. <span class="n">ax</span><span class="o">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s1">&#39;W&#39;</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="n">STAT_LOGGER_FONT_SIZE</span><span class="p">)</span>
  591. <span class="n">ax</span><span class="o">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s1">&#39;H&#39;</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="n">STAT_LOGGER_FONT_SIZE</span><span class="p">)</span>
  592. <span class="n">x</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
  593. <span class="n">y</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
  594. <span class="n">xx</span><span class="p">,</span> <span class="n">yy</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">meshgrid</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">sparse</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
  595. <span class="n">points</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">xx</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="n">yy</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)])</span>
  596. <span class="n">color</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_score</span><span class="p">(</span><span class="n">anchors_boxes</span><span class="p">,</span> <span class="n">points</span><span class="p">,</span> <span class="n">image_size</span><span class="p">)</span>
  597. <span class="n">ax</span><span class="o">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s1">&#39;W&#39;</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="n">STAT_LOGGER_FONT_SIZE</span><span class="p">)</span>
  598. <span class="n">ax</span><span class="o">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s1">&#39;H&#39;</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="n">STAT_LOGGER_FONT_SIZE</span><span class="p">)</span>
  599. <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">color</span><span class="p">,</span> <span class="n">interpolation</span><span class="o">=</span><span class="s1">&#39;nearest&#39;</span><span class="p">,</span> <span class="n">origin</span><span class="o">=</span><span class="s1">&#39;lower&#39;</span><span class="p">,</span>
  600. <span class="n">extent</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">image_size</span><span class="p">])</span>
  601. <span class="c1"># calculate the coverage for the dataset labels</span>
  602. <span class="n">cover_masks</span> <span class="o">=</span> <span class="p">[]</span>
  603. <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">anchors_len</span><span class="p">):</span>
  604. <span class="n">w_max</span> <span class="o">=</span> <span class="p">(</span><span class="n">anchors_boxes</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span> <span class="o">/</span> <span class="n">image_size</span><span class="p">)</span> <span class="o">*</span> <span class="mi">4</span>
  605. <span class="n">w_min</span> <span class="o">=</span> <span class="p">(</span><span class="n">anchors_boxes</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span> <span class="o">/</span> <span class="n">image_size</span><span class="p">)</span> <span class="o">*</span> <span class="mf">0.25</span>
  606. <span class="n">h_max</span> <span class="o">=</span> <span class="p">(</span><span class="n">anchors_boxes</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">1</span><span class="p">]</span> <span class="o">/</span> <span class="n">image_size</span><span class="p">)</span> <span class="o">*</span> <span class="mi">4</span>
  607. <span class="n">h_min</span> <span class="o">=</span> <span class="p">(</span><span class="n">anchors_boxes</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">1</span><span class="p">]</span> <span class="o">/</span> <span class="n">image_size</span><span class="p">)</span> <span class="o">*</span> <span class="mf">0.25</span>
  608. <span class="n">cover_masks</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">logical_and</span><span class="p">(</span>
  609. <span class="n">np</span><span class="o">.</span><span class="n">logical_and</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">logical_and</span><span class="p">(</span><span class="n">labels</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">&lt;</span> <span class="n">w_max</span><span class="p">,</span> <span class="n">labels</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">&gt;</span> <span class="n">w_min</span><span class="p">),</span> <span class="n">labels</span><span class="p">[:,</span> <span class="mi">4</span><span class="p">]</span> <span class="o">&lt;</span> <span class="n">h_max</span><span class="p">),</span>
  610. <span class="n">labels</span><span class="p">[:,</span> <span class="mi">4</span><span class="p">]</span> <span class="o">&gt;</span> <span class="n">h_min</span><span class="p">))</span>
  611. <span class="n">cover_masks</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">cover_masks</span><span class="p">)</span>
  612. <span class="n">coverage</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">count_nonzero</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">any</span><span class="p">(</span><span class="n">cover_masks</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">))</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="n">labels</span><span class="p">)</span>
  613. <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">add_figure</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">title</span><span class="si">}</span><span class="s1"> anchors coverage&#39;</span><span class="p">,</span> <span class="n">figure</span><span class="o">=</span><span class="n">fig</span><span class="p">)</span>
  614. <span class="k">return</span> <span class="n">coverage</span></div>
  615. <div class="viewcode-block" id="get_color_augmentation"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.datasets_utils.get_color_augmentation">[docs]</a><span class="k">def</span> <span class="nf">get_color_augmentation</span><span class="p">(</span><span class="n">rand_augment_config_string</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">color_jitter</span><span class="p">:</span> <span class="nb">tuple</span><span class="p">,</span> <span class="n">crop_size</span><span class="o">=</span><span class="mi">224</span><span class="p">,</span> <span class="n">img_mean</span><span class="o">=</span><span class="p">[</span><span class="mf">0.485</span><span class="p">,</span> <span class="mf">0.456</span><span class="p">,</span> <span class="mf">0.406</span><span class="p">]):</span>
  616. <span class="sd">&quot;&quot;&quot;</span>
  617. <span class="sd"> Returns color augmentation class. As these augmentation cannot work on top one another, only one is returned</span>
  618. <span class="sd"> according to rand_augment_config_string</span>
  619. <span class="sd"> :param rand_augment_config_string: string which defines the auto augment configurations.</span>
  620. <span class="sd"> If none, color jitter will be returned. For possibile values see auto_augment.py</span>
  621. <span class="sd"> :param color_jitter: tuple for color jitter value.</span>
  622. <span class="sd"> :param crop_size: relevant only for auto augment</span>
  623. <span class="sd"> :param img_mean: relevant only for auto augment</span>
  624. <span class="sd"> :return: RandAugment transform or ColorJitter</span>
  625. <span class="sd"> &quot;&quot;&quot;</span>
  626. <span class="k">if</span> <span class="n">rand_augment_config_string</span><span class="p">:</span>
  627. <span class="n">auto_augment_params</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(</span><span class="n">translate_const</span><span class="o">=</span><span class="nb">int</span><span class="p">(</span><span class="n">crop_size</span> <span class="o">*</span> <span class="mf">0.45</span><span class="p">),</span>
  628. <span class="n">img_mean</span><span class="o">=</span><span class="nb">tuple</span><span class="p">([</span><span class="nb">min</span><span class="p">(</span><span class="mi">255</span><span class="p">,</span> <span class="nb">round</span><span class="p">(</span><span class="mi">255</span> <span class="o">*</span> <span class="n">x</span><span class="p">))</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">img_mean</span><span class="p">]))</span>
  629. <span class="n">color_augmentation</span> <span class="o">=</span> <span class="n">rand_augment_transform</span><span class="p">(</span><span class="n">rand_augment_config_string</span><span class="p">,</span> <span class="n">auto_augment_params</span><span class="p">)</span>
  630. <span class="k">else</span><span class="p">:</span> <span class="c1"># RandAugment includes colorjitter like augmentations, both cannot be applied together.</span>
  631. <span class="n">color_augmentation</span> <span class="o">=</span> <span class="n">transforms</span><span class="o">.</span><span class="n">ColorJitter</span><span class="p">(</span><span class="o">*</span><span class="n">color_jitter</span><span class="p">)</span>
  632. <span class="k">return</span> <span class="n">color_augmentation</span></div>
  633. <div class="viewcode-block" id="worker_init_reset_seed"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.datasets_utils.worker_init_reset_seed">[docs]</a><span class="k">def</span> <span class="nf">worker_init_reset_seed</span><span class="p">(</span><span class="n">worker_id</span><span class="p">):</span>
  634. <span class="sd">&quot;&quot;&quot;</span>
  635. <span class="sd"> Make sure each process has different random seed, especially for &#39;fork&#39; method.</span>
  636. <span class="sd"> Check https://github.com/pytorch/pytorch/issues/63311 for more details.</span>
  637. <span class="sd"> :param worker_id: placeholder (needs to be passed to DataLoader init).</span>
  638. <span class="sd"> &quot;&quot;&quot;</span>
  639. <span class="n">seed</span> <span class="o">=</span> <span class="n">uuid</span><span class="o">.</span><span class="n">uuid4</span><span class="p">()</span><span class="o">.</span><span class="n">int</span> <span class="o">%</span> <span class="mi">2</span> <span class="o">**</span> <span class="mi">32</span>
  640. <span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span>
  641. <span class="n">torch</span><span class="o">.</span><span class="n">set_rng_state</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span><span class="o">.</span><span class="n">get_state</span><span class="p">())</span>
  642. <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span></div>
  643. </pre></div>
  644. </div>
  645. </div>
  646. <footer>
  647. <hr/>
  648. <div role="contentinfo">
  649. <p>&#169; Copyright 2021, SuperGradients team.</p>
  650. </div>
  651. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  652. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  653. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  654. </footer>
  655. </div>
  656. </div>
  657. </section>
  658. </div>
  659. <script>
  660. jQuery(function () {
  661. SphinxRtdTheme.Navigation.enable(true);
  662. });
  663. </script>
  664. </body>
  665. </html>
Discard
Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
@@ -3,10 +3,11 @@
 <head>
 <head>
   <meta charset="utf-8" />
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-  <title>super_gradients.training.datasets.detection_datasets.coco_detection &mdash; SuperGradients 1.0 documentation</title>
+  <title>super_gradients.training.datasets.detection_datasets.coco_detection &mdash; SuperGradients 3.0.3 documentation</title>
       <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
+      <link rel="stylesheet" href="../../../../../_static/custom.css" type="text/css" />
   <!--[if lt IE 9]>
   <!--[if lt IE 9]>
     <script src="../../../../../_static/js/html5shiv.min.js"></script>
     <script src="../../../../../_static/js/html5shiv.min.js"></script>
   <![endif]-->
   <![endif]-->
@@ -14,7 +15,9 @@
         <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
         <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
         <script src="../../../../../_static/jquery.js"></script>
         <script src="../../../../../_static/jquery.js"></script>
         <script src="../../../../../_static/underscore.js"></script>
         <script src="../../../../../_static/underscore.js"></script>
+        <script src="../../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
         <script src="../../../../../_static/doctools.js"></script>
         <script src="../../../../../_static/doctools.js"></script>
+        <script src="../../../../../_static/sphinx_highlight.js"></script>
     <script src="../../../../../_static/js/theme.js"></script>
     <script src="../../../../../_static/js/theme.js"></script>
     <link rel="index" title="Index" href="../../../../../genindex.html" />
     <link rel="index" title="Index" href="../../../../../genindex.html" />
     <link rel="search" title="Search" href="../../../../../search.html" /> 
     <link rel="search" title="Search" href="../../../../../search.html" /> 
@@ -35,30 +38,29 @@
   </form>
   </form>
 </div>
 </div>
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
-              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
+              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
 <ul>
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#supergradients">SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#quick-installation">Quick Installation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#what-s-new">What’s New</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#coming-soon">Coming soon</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#table-of-content">Table of Content</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#getting-started">Getting Started</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#advanced-features">Advanced Features</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#installation-methods">Installation Methods</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#documentation">Documentation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#contributing">Contributing</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#citation">Citation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#community">Community</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#license">License</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#deci-platform">Deci Platform</a></li>
 </ul>
 </ul>
-<p class="caption"><span class="caption-text">Technical Documentation</span></p>
+<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
 <ul>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
-</ul>
-<p class="caption"><span class="caption-text">User Guide</span></p>
-<ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html">What is SuperGradients?</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#installation">Installation</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#training-parameters">Training Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#network-architectures">Network Architectures</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
 </ul>
 </ul>
 
 
         </div>
         </div>
@@ -87,106 +89,97 @@
              
              
   <h1>Source code for super_gradients.training.datasets.detection_datasets.coco_detection</h1><div class="highlight"><pre>
   <h1>Source code for super_gradients.training.datasets.detection_datasets.coco_detection</h1><div class="highlight"><pre>
 <span></span><span class="kn">import</span> <span class="nn">os</span>
 <span></span><span class="kn">import</span> <span class="nn">os</span>
-<span class="kn">from</span> <span class="nn">pycocotools.coco</span> <span class="kn">import</span> <span class="n">COCO</span>
-<span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
-<span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">Dataset</span>
-<span class="kn">import</span> <span class="nn">random</span>
 <span class="kn">import</span> <span class="nn">cv2</span>
 <span class="kn">import</span> <span class="nn">cv2</span>
+
 <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
 <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
-<span class="kn">from</span> <span class="nn">tqdm</span> <span class="kn">import</span> <span class="n">tqdm</span>
-<span class="kn">from</span> <span class="nn">multiprocessing.pool</span> <span class="kn">import</span> <span class="n">ThreadPool</span>
+<span class="kn">from</span> <span class="nn">pycocotools.coco</span> <span class="kn">import</span> <span class="n">COCO</span>
+
+<span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
+<span class="kn">from</span> <span class="nn">super_gradients.training.datasets.detection_datasets.detection_dataset</span> <span class="kn">import</span> <span class="n">DetectionDataset</span>
+<span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">DetectionTargetsFormat</span>
+<span class="kn">from</span> <span class="nn">super_gradients.training.datasets.datasets_conf</span> <span class="kn">import</span> <span class="n">COCO_DETECTION_CLASSES_LIST</span>
 
 
 <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
 <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
 
 
 
 
-<div class="viewcode-block" id="COCODetectionDataset"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.detection_datasets.html#super_gradients.training.datasets.COCODetectionDataset">[docs]</a><span class="k">class</span> <span class="nc">COCODetectionDataset</span><span class="p">(</span><span class="n">Dataset</span><span class="p">):</span>
-    <span class="sd">&quot;&quot;&quot;Detection dataset COCO implementation&quot;&quot;&quot;</span>
+<div class="viewcode-block" id="COCODetectionDataset"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.COCODetectionDataset">[docs]</a><span class="k">class</span> <span class="nc">COCODetectionDataset</span><span class="p">(</span><span class="n">DetectionDataset</span><span class="p">):</span>
+    <span class="sd">&quot;&quot;&quot;Dataset for COCO object detection.&quot;&quot;&quot;</span>
 
 
     <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
     <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
-            <span class="bp">self</span><span class="p">,</span> <span class="n">img_size</span><span class="p">:</span> <span class="nb">tuple</span><span class="p">,</span>
-            <span class="n">data_dir</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
-            <span class="n">json_file</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;instances_train2017.json&quot;</span><span class="p">,</span>
-            <span class="n">name</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;images/train2017&quot;</span><span class="p">,</span>
-            <span class="n">cache</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
-            <span class="n">cache_dir_path</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
-            <span class="n">tight_box_rotation</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
-            <span class="n">transforms</span><span class="p">:</span> <span class="nb">list</span> <span class="o">=</span> <span class="p">[],</span>
-            <span class="n">with_crowd</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">json_file</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;instances_train2017.json&quot;</span><span class="p">,</span>
+        <span class="n">subdir</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;images/train2017&quot;</span><span class="p">,</span>
+        <span class="n">tight_box_rotation</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
+        <span class="n">with_crowd</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
+        <span class="o">*</span><span class="n">args</span><span class="p">,</span>
+        <span class="o">**</span><span class="n">kwargs</span><span class="p">,</span>
     <span class="p">):</span>
     <span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
-<span class="sd">        :param img_size: tuple, Image size (when loaded, before transforms)</span>
-<span class="sd">        :param data_dir: str, root path to coco data.</span>
-<span class="sd">        :param json_file: str, path to coco json file, that resides in data_dir/annotations/json_file.</span>
-<span class="sd">        :param name: str, sub directory of data_dir containing the data.</span>
-<span class="sd">        :param cache: bool, whether to cache images</span>
-<span class="sd">        :param cache_dir_path: str, path to a directory that will be used for caching (with memmap).</span>
-<span class="sd">        :param tight_box_rotation: bool, whether to use of segmentation maps convex hull</span>
-<span class="sd">         as target_seg (see load_sample docs).</span>
-<span class="sd">        :param transforms: list of transforms to apply sequentially on sample in __getitem__</span>
+<span class="sd">        :param json_file:           Name of the coco json file, that resides in data_dir/annotations/json_file.</span>
+<span class="sd">        :param subdir:              Sub directory of data_dir containing the data.</span>
+<span class="sd">        :param tight_box_rotation:  bool, whether to use of segmentation maps convex hull as target_seg</span>
+<span class="sd">                                    (check get_sample docs).</span>
 <span class="sd">        :param with_crowd: Add the crowd groundtruths to __getitem__</span>
 <span class="sd">        :param with_crowd: Add the crowd groundtruths to __getitem__</span>
+
+<span class="sd">        kwargs:</span>
+<span class="sd">            all_classes_list: all classes list, default is COCO_DETECTION_CLASSES_LIST.</span>
 <span class="sd">        &quot;&quot;&quot;</span>
 <span class="sd">        &quot;&quot;&quot;</span>
-        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">imgs</span> <span class="o">=</span> <span class="kc">None</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">data_dir</span> <span class="o">=</span> <span class="n">data_dir</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span> <span class="o">=</span> <span class="n">img_size</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">data_dir</span> <span class="o">=</span> <span class="n">data_dir</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">subdir</span> <span class="o">=</span> <span class="n">subdir</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">json_file</span> <span class="o">=</span> <span class="n">json_file</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">json_file</span> <span class="o">=</span> <span class="n">json_file</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">name</span> <span class="o">=</span> <span class="n">name</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">tight_box_rotation</span> <span class="o">=</span> <span class="n">tight_box_rotation</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">with_crowd</span> <span class="o">=</span> <span class="n">with_crowd</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">with_crowd</span> <span class="o">=</span> <span class="n">with_crowd</span>
 
 
+        <span class="n">target_fields</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;target&quot;</span><span class="p">,</span> <span class="s2">&quot;crowd_target&quot;</span><span class="p">]</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">with_crowd</span> <span class="k">else</span> <span class="p">[</span><span class="s2">&quot;target&quot;</span><span class="p">]</span>
+        <span class="n">kwargs</span><span class="p">[</span><span class="s2">&quot;target_fields&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">target_fields</span>
+        <span class="n">kwargs</span><span class="p">[</span><span class="s2">&quot;output_fields&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">,</span> <span class="o">*</span><span class="n">target_fields</span><span class="p">]</span>
+        <span class="n">kwargs</span><span class="p">[</span><span class="s2">&quot;original_target_format&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">DetectionTargetsFormat</span><span class="o">.</span><span class="n">XYXY_LABEL</span>
+        <span class="n">kwargs</span><span class="p">[</span><span class="s2">&quot;all_classes_list&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;all_classes_list&quot;</span><span class="p">)</span> <span class="ow">or</span> <span class="n">COCO_DETECTION_CLASSES_LIST</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
+
+    <span class="k">def</span> <span class="nf">_setup_data_source</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
+        <span class="sd">&quot;&quot;&quot;Initialize img_and_target_path_list and warn if label file is missing</span>
+
+<span class="sd">        :return: List of tuples made of (img_path,target_path)</span>
+<span class="sd">        &quot;&quot;&quot;</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">coco</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_init_coco</span><span class="p">()</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">class_ids</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">coco</span><span class="o">.</span><span class="n">getCatIds</span><span class="p">())</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="nb">list</span><span class="p">([</span><span class="n">category</span><span class="p">[</span><span class="s2">&quot;name&quot;</span><span class="p">]</span> <span class="k">for</span> <span class="n">category</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">coco</span><span class="o">.</span><span class="n">loadC
+        <span class="bp">self</span><span class="o">.</span><span class="n">sample_id_to_coco_id</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">coco</span><span class="o">.</span><span class="n">getImgIds</span><span class="p">()</span>
+        <span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">sample_id_to_coco_id</span><span class="p">)</span>
+
+    <span class="k">def</span> <span class="nf">_init_coco</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">COCO</span><span class="p">:</span>
         <span class="n">annotation_file_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">data_dir</span><span class="p">,</span> <span class="s2">&quot;annotations&quot;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">json_file</span><span c
         <span class="n">annotation_file_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">data_dir</span><span class="p">,</span> <span class="s2">&quot;annotations&quot;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">json_file</span><span c
         <span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">annotation_file_path</span><span class="p">):</span>
         <span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">annotation_file_path</span><span class="p">):</span>
             <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Could not find annotation file under &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">annotation_file_path</span><span class="p">))</span>
             <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Could not find annotation file under &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">annotation_file_path</span><span class="p">))</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">coco</span> <span class="o">=</span> <span class="n">COCO</span><span class="p">(</span><span class="n">annotation_file_path</span><span class="p">)</span>  <span class="c1"># duplicate</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">tight_box_rotation</span> <span class="o">=</span> <span class="n">tight_box_rotation</span>
-        <span class="n">remove_useless_info</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">coco</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">tight_box_rotation</span><span class="p">)</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">coco</span><span class="o">.</span><span class="n">getImgIds</span><span class="p">()</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">class_ids</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">coco</span><span class="o">.</span><span class="n">getCatIds</span><span class="p">())</span>
-        <span class="n">cats</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">coco</span><span class="o">.</span><span class="n">loadCats</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">coco</span><span class="o">.</span><span class="n">getCatIds</span><span class="p">())</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">_classes</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">([</span><span class="n">c</span><span class="p">[</span><span class="s2">&quot;name&quot;</span><span class="p">]</span> <span class="k">for</span> <span class="n">c</span> <span class="ow">in</span> <span class="n">cats</span><span class="p">])</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">annotations</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_load_coco_annotations</span><span class="p">()</span>
-        <span class="k">if</span> <span class="n">cache</span><span class="p">:</span>  <span class="c1"># cache after merged</span>
-            <span class="k">if</span> <span class="n">cache_dir_path</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">cache_dir_path</span><span class="p">):</span>
-                <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Must pass valid path through cache_dir_path when caching. Got &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">cache_dir_path</span><span class="p">))</span>
-            <span class="bp">self</span><span class="o">.</span><span class="n">cache_dir_path</span> <span class="o">=</span> <span class="n">cache_dir_path</span>
-            <span class="bp">self</span><span class="o">.</span><span class="n">_cache_images</span><span class="p">()</span>
-
-        <span class="bp">self</span><span class="o">.</span><span class="n">transforms</span> <span class="o">=</span> <span class="n">transforms</span>
-
-    <span class="k">def</span> <span class="fm">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">idx</span><span class="p">):</span>
-        <span class="n">sample</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">load_sample</span><span class="p">(</span><span class="n">idx</span><span class="p">)</span>
-        <span class="n">sample</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">apply_transforms</span><span class="p">(</span><span class="n">sample</span><span class="p">)</span>
-        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">with_crowd</span><span class="p">:</span>
-            <span class="k">return</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">],</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;target&quot;</span><span class="p">],</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;crowd_target&quot;</span><span class="p">],</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;info&quot;</s
-        <span class="k">else</span><span class="p">:</span>
-            <span class="k">return</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">],</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;target&quot;</span><span class="p">],</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;info&quot;</span><span class="p">],</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;id&quot;</span><span 
-
-    <span class="k">def</span> <span class="fm">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
-        <span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ids</span><span class="p">)</span>
-
-    <span class="k">def</span> <span class="nf">_load_coco_annotations</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
-        <span class="k">return</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_load_anno_from_ids</span><span class="p">(</span><span class="n">_ids</span><span class="p">)</span> <span class="k">for</span> <span class="n">_ids</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ids</span><span class="p">,</span> <span class="n">desc</span><spa
-
-    <span class="k">def</span> <span class="nf">_load_anno_from_ids</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">id_</span><span class="p">):</span>
+
+        <span class="n">coco</span> <span class="o">=</span> <span class="n">COCO</span><span class="p">(</span><span class="n">annotation_file_path</span><span class="p">)</span>
+        <span class="n">remove_useless_info</span><span class="p">(</span><span class="n">coco</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">tight_box_rotation</span><span class="p">)</span>
+        <span class="k">return</span> <span class="n">coco</span>
+
+    <span class="k">def</span> <span class="nf">_load_annotation</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample_id</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">dict</span><span class="p">:</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
-<span class="sd">        Load relevant information of a specific image</span>
-
-<span class="sd">        :param id_: image id</span>
-<span class="sd">        :return res:            Target Bboxes (detection)</span>
-<span class="sd">        :return res_crowd:      Crowd target Bboxes (detection)</span>
-<span class="sd">        :return res_seg:        Segmentation</span>
-<span class="sd">        :return img_info:       Image (height, width)</span>
-<span class="sd">        :return resized_info:   Resides image (height, width)</span>
-<span class="sd">        :return img_path:       Path to the associated image</span>
+<span class="sd">        Load relevant information of a specific image.</span>
+
+<span class="sd">        :param sample_id:               Sample_id in the dataset</span>
+<span class="sd">        :return target:                 Target Bboxes (detection) in XYXY_LABEL format</span>
+<span class="sd">        :return crowd_target:           Crowd target Bboxes (detection) in XYXY_LABEL format</span>
+<span class="sd">        :return target_segmentation:    Segmentation</span>
+<span class="sd">        :return initial_img_shape:      Image (height, width)</span>
+<span class="sd">        :return resized_img_shape:      Resides image (height, width)</span>
+<span class="sd">        :return img_path:               Path to the associated image</span>
 <span class="sd">        &quot;&quot;&quot;</span>
 <span class="sd">        &quot;&quot;&quot;</span>
-        <span class="n">im_ann</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">coco</span><span class="o">.</span><span class="n">loadImgs</span><span class="p">(</span><span class="n">id_</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
-        <span class="n">width</span> <span class="o">=</span> <span class="n">im_ann</span><span class="p">[</span><span class="s2">&quot;width&quot;</span><span class="p">]</span>
-        <span class="n">height</span> <span class="o">=</span> <span class="n">im_ann</span><span class="p">[</span><span class="s2">&quot;height&quot;</span><span class="p">]</span>
-        <span class="n">anno_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">coco</span><span class="o">.</span><span class="n">getAnnIds</span><span class="p">(</span><span class="n">imgIds</span><span class="o">=</span><span class="p">[</span><span class="nb">int</span><span class="p">(</span><span class="n">id_</span><span class="p">)])</span>
-        <span class="n">annotations</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">coco</span><span class="o">.</span><span class="n">loadAnns</span><span class="p">(</span><span class="n">anno_ids</span><span class="p">)</span>
+
+        <span class="n">img_id</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sample_id_to_coco_id</span><span class="p">[</span><span class="n">sample_id</span><span class="p">]</span>
+
+        <span class="n">img_metadata</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">coco</span><span class="o">.</span><span class="n">loadImgs</span><span class="p">(</span><span class="n">img_id</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
+        <span class="n">width</span> <span class="o">=</span> <span class="n">img_metadata</span><span class="p">[</span><span class="s2">&quot;width&quot;</span><span class="p">]</span>
+        <span class="n">height</span> <span class="o">=</span> <span class="n">img_metadata</span><span class="p">[</span><span class="s2">&quot;height&quot;</span><span class="p">]</span>
+
+        <span class="n">img_annotation_ids</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">coco</span><span class="o">.</span><span class="n">getAnnIds</span><span class="p">(</span><span class="n">imgIds</span><span class="o">=</span><span class="p">[</span><span class="nb">int</span><span class="p">(</span><span class="n">img_id</span><span class="p">)])</span>
+        <span class="n">img_annotations</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">coco</span><span class="o">.</span><span class="n">loadAnns</span><span class="p">(</span><span class="n">img_annotation_ids</span><span class="p">)</span>
 
 
         <span class="n">cleaned_annotations</span> <span class="o">=</span> <span class="p">[]</span>
         <span class="n">cleaned_annotations</span> <span class="o">=</span> <span class="p">[]</span>
-        <span class="k">for</span> <span class="n">annotation</span> <span class="ow">in</span> <span class="n">annotations</span><span class="p">:</span>
+        <span class="k">for</span> <span class="n">annotation</span> <span class="ow">in</span> <span class="n">img_annotations</span><span class="p">:</span>
             <span class="n">x1</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;bbox&quot;</span><span class="p">][</span><span class="mi">0</span><span class="p">]))</span>
             <span class="n">x1</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;bbox&quot;</span><span class="p">][</span><span class="mi">0</span><span class="p">]))</span>
             <span class="n">y1</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;bbox&quot;</span><span class="p">][</span><span class="mi">1</span><span class="p">]))</span>
             <span class="n">y1</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;bbox&quot;</span><span class="p">][</span><span class="mi">1</span><span class="p">]))</span>
             <span class="n">x2</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">min</span><span class="p">((</span><span class="n">width</span><span class="p">,</span> <span class="n">x1</span> <span class="o">+</span> <span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;bb
             <span class="n">x2</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">min</span><span class="p">((</span><span class="n">width</span><span class="p">,</span> <span class="n">x1</span> <span class="o">+</span> <span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;bb
@@ -197,14 +190,14 @@
 
 
         <span class="n">non_crowd_annotations</span> <span class="o">=</span> <span class="p">[</span><span class="n">annotation</span> <span class="k">for</span> <span class="n">annotation</span> <span class="ow">in</span> <span class="n">cleaned_annotations</span> <span class="k">if</span> <span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;iscrowd&quot;</span><span class="p">]</span> <span class="o">==</span> <span class="mi">0</span><span class="p">]</span>
         <span class="n">non_crowd_annotations</span> <span class="o">=</span> <span class="p">[</span><span class="n">annotation</span> <span class="k">for</span> <span class="n">annotation</span> <span class="ow">in</span> <span class="n">cleaned_annotations</span> <span class="k">if</span> <span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;iscrowd&quot;</span><span class="p">]</span> <span class="o">==</span> <span class="mi">0</span><span class="p">]</span>
 
 
-        <span class="n">res</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="nb">len</span><span class="p">(</span><span class="n">non_crowd_annotations</span><span class="p">),</span> <span class="mi">5</span><span class="p">))</span>
+        <span class="n">target</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="nb">len</span><span class="p">(</span><span class="n">non_crowd_annotations</span><span class="p">),</span> <span class="mi">5</span><span class="p">))</span>
         <span class="n">num_seg_values</span> <span class="o">=</span> <span class="mi">98</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">tight_box_rotation</span> <span class="k">else</span> <span class="mi">0</span>
         <span class="n">num_seg_values</span> <span class="o">=</span> <span class="mi">98</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">tight_box_rotation</span> <span class="k">else</span> <span class="mi">0</span>
-        <span class="n">res_seg</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="nb">len</span><span class="p">(</span><span class="n">non_crowd_annotations</span><span class="p">),</span> <span class="n">num_seg_values</span><span class="p">))</span>
-        <span class="n">res_seg</span><span class="o">.</span><span class="n">fill</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">nan</span><span class="p">)</span>
+        <span class="n">target_segmentation</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="nb">len</span><span class="p">(</span><span class="n">non_crowd_annotations</span><span class="p">),</span> <span class="n">num_seg_values</span><span class="p">))</span>
+        <span class="n">target_segmentation</span><span class="o">.</span><span class="n">fill</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">nan</span><span class="p">)</span>
         <span class="k">for</span> <span class="n">ix</span><span class="p">,</span> <span class="n">annotation</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">non_crowd_annotations</span><span class="p">):</span>
         <span class="k">for</span> <span class="n">ix</span><span class="p">,</span> <span class="n">annotation</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">non_crowd_annotations</span><span class="p">):</span>
             <span class="bp">cls</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">class_ids</span><span class="o">.</span><span class="n">index</span><span class="p">(</span><span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;category_id&quot;</span><span class="p">])</span>
             <span class="bp">cls</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">class_ids</span><span class="o">.</span><span class="n">index</span><span class="p">(</span><span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;category_id&quot;</span><span class="p">])</span>
-            <span class="n">res</span><span class="p">[</span><span class="n">ix</span><span class="p">,</span> <span class="mi">0</span><span class="p">:</span><span class="mi">4</span><span class="p">]</span> <span class="o">=</span> <span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;clean_bbox&quot;</span><span class="p">]</span>
-            <span class="n">res</span><span class="p">[</span><span class="n">ix</span><span class="p">,</span> <span class="mi">4</span><span class="p">]</span> <span class="o">=</span> <span class="bp">cls</span>
+            <span class="n">target</span><span class="p">[</span><span class="n">ix</span><span class="p">,</span> <span class="mi">0</span><span class="p">:</span><span class="mi">4</span><span class="p">]</span> <span class="o">=</span> <span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;clean_bbox&quot;</span><span class="p">]</span>
+            <span class="n">target</span><span class="p">[</span><span class="n">ix</span><span class="p">,</span> <span class="mi">4</span><span class="p">]</span> <span class="o">=</span> <span class="bp">cls</span>
             <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">tight_box_rotation</span><span class="p">:</span>
             <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">tight_box_rotation</span><span class="p">:</span>
                 <span class="n">seg_points</span> <span class="o">=</span> <span class="p">[</span><span class="n">j</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">annotation</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;segmentation&quot;</span><span class="p">,</span> <span class="p">[])</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span clas
                 <span class="n">seg_points</span> <span class="o">=</span> <span class="p">[</span><span class="n">j</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">annotation</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;segmentation&quot;</span><span class="p">,</span> <span class="p">[])</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span clas
                 <span class="k">if</span> <span class="n">seg_points</span><span class="p">:</span>
                 <span class="k">if</span> <span class="n">seg_points</span><span class="p">:</span>
@@ -212,181 +205,41 @@
                     <span class="n">seg_points_convex</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">convexHull</span><span class="p">(</span><span class="n">seg_points_c</span><span class="p">)</span><span class="o">.</span><span class="n">ravel</span><span class="p">()</span>
                     <span class="n">seg_points_convex</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">convexHull</span><span class="p">(</span><span class="n">seg_points_c</span><span class="p">)</span><span class="o">.</span><span class="n">ravel</span><span class="p">()</span>
                 <span class="k">else</span><span class="p">:</span>
                 <span class="k">else</span><span class="p">:</span>
                     <span class="n">seg_points_convex</span> <span class="o">=</span> <span class="p">[]</span>
                     <span class="n">seg_points_convex</span> <span class="o">=</span> <span class="p">[]</span>
-                <span class="n">res_seg</span><span class="p">[</span><span class="n">ix</span><span class="p">,</span> <span class="p">:</span><span class="nb">len</span><span class="p">(</span><span class="n">seg_points_convex</span><span class="p">)]</span> <span class="o">=</span> <span class="n">seg_points_convex</span>
+                <span class="n">target_segmentation</span><span class="p">[</span><span class="n">ix</span><span class="p">,</span> <span class="p">:</span> <span class="nb">len</span><span class="p">(</span><span class="n">seg_points_convex</span><span class="p">)]</span> <span class="o">=</span> <span class="n">seg_points_convex</span>
 
 
         <span class="n">crowd_annotations</span> <span class="o">=</span> <span class="p">[</span><span class="n">annotation</span> <span class="k">for</span> <span class="n">annotation</span> <span class="ow">in</span> <span class="n">cleaned_annotations</span> <span class="k">if</span> <span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;iscrowd&quot;</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span><span class="p">]</span>
         <span class="n">crowd_annotations</span> <span class="o">=</span> <span class="p">[</span><span class="n">annotation</span> <span class="k">for</span> <span class="n">annotation</span> <span class="ow">in</span> <span class="n">cleaned_annotations</span> <span class="k">if</span> <span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;iscrowd&quot;</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span><span class="p">]</span>
 
 
-        <span class="n">res_crowd</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="nb">len</span><span class="p">(</span><span class="n">crowd_annotations</span><span class="p">),</span> <span class="mi">5</span><span class="p">))</span>
+        <span class="n">crowd_target</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="nb">len</span><span class="p">(</span><span class="n">crowd_annotations</span><span class="p">),</span> <span class="mi">5</span><span class="p">))</span>
         <span class="k">for</span> <span class="n">ix</span><span class="p">,</span> <span class="n">annotation</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">crowd_annotations</span><span class="p">):</span>
         <span class="k">for</span> <span class="n">ix</span><span class="p">,</span> <span class="n">annotation</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">crowd_annotations</span><span class="p">):</span>
             <span class="bp">cls</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">class_ids</span><span class="o">.</span><span class="n">index</span><span class="p">(</span><span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;category_id&quot;</span><span class="p">])</span>
             <span class="bp">cls</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">class_ids</span><span class="o">.</span><span class="n">index</span><span class="p">(</span><span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;category_id&quot;</span><span class="p">])</span>
-            <span class="n">res_crowd</span><span class="p">[</span><span class="n">ix</span><span class="p">,</span> <span class="mi">0</span><span class="p">:</span><span class="mi">4</span><span class="p">]</span> <span class="o">=</span> <span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;clean_bbox&quot;</span><span class="p">]</span>
-            <span class="n">res_crowd</span><span class="p">[</span><span class="n">ix</span><span class="p">,</span> <span class="mi">4</span><span class="p">]</span> <span class="o">=</span> <span class="bp">cls</span>
+            <span class="n">crowd_target</span><span class="p">[</span><span class="n">ix</span><span class="p">,</span> <span class="mi">0</span><span class="p">:</span><span class="mi">4</span><span class="p">]</span> <span class="o">=</span> <span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;clean_bbox&quot;</span><span class="p">]</span>
+            <span class="n">crowd_target</span><span class="p">[</span><span class="n">ix</span><span class="p">,</span> <span class="mi">4</span><span class="p">]</span> <span class="o">=</span> <span class="bp">cls</span>
 
 
         <span class="n">r</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">/</span> <span class="n">height</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">1</span><span class="p">]</s
         <span class="n">r</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">/</span> <span class="n">height</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">1</span><span class="p">]</s
-        <span class="n">res</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">]</span> <span class="o">*=</span> <span class="n">r</span>
-        <span class="n">res_crowd</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">]</span> <span class="o">*=</span> <span class="n">r</span>
-        <span class="n">res_seg</span> <span class="o">*=</span> <span class="n">r</span>
-
-        <span class="n">img_info</span> <span class="o">=</span> <span class="p">(</span><span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">)</span>
-        <span class="n">resized_info</span> <span class="o">=</span> <span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">height</span> <span class="o">*</span> <span class="n">r</span><span class="p">),</span> <span class="nb">int</span><span class="p">(</span><span class="n">width</span> <span class="o">*</span> <span class="n">r</span><span class="p">))</span>
-
-        <span class="n">file_name</span> <span class="o">=</span> <span class="p">(</span>
-            <span class="n">im_ann</span><span class="p">[</span><span class="s2">&quot;file_name&quot;</span><span class="p">]</span>
-            <span class="k">if</span> <span class="s2">&quot;file_name&quot;</span> <span class="ow">in</span> <span class="n">im_ann</span>
-            <span class="k">else</span> <span class="s2">&quot;</span><span class="si">{:012}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">id_</span><span class="p">)</span> <span class="o">+</span> <span class="s2">&quot;.jpg&quot;</span>
-        <span class="p">)</span>
-        <span class="n">img_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">data_dir</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">file_name</span><span class="p">)</span>
-        <span class="k">return</span> <span class="n">res</span><span class="p">,</span> <span class="n">res_crowd</span><span class="p">,</span> <span class="n">res_seg</span><span class="p">,</span> <span class="n">img_info</span><span class="p">,</span> <span class="n">resized_info</span><span class="p">,</span> <span class="n">img_path</span>
-
-    <span class="k">def</span> <span class="nf">_cache_images</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
-        <span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
-            <span class="s2">&quot;</span><span class="se">\n</span><span class="s2">********************************************************************************</span><span class="se">\n</span><span class="s2">&quot;</span>
-            <span class="s2">&quot;You are using cached images in RAM to accelerate training.</span><span class="se">\n</span><span class="s2">&quot;</span>
-            <span class="s2">&quot;This requires large system RAM.</span><span class="se">\n</span><span class="s2">&quot;</span>
-            <span class="s2">&quot;Make sure you have 200G+ RAM and 136G available disk space for training COCO.</span><span class="se">\n</span><span class="s2">&quot;</span>
-            <span class="s2">&quot;********************************************************************************</span><span class="se">\n</span><span class="s2">&quot;</span>
-        <span class="p">)</span>
-        <span class="n">max_h</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
-        <span class="n">max_w</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
-        <span class="n">cache_file</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cache_dir_path</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;img_resized_cache_</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">n
-        <span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">cache_file</span><span class="p">):</span>
-            <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span>
-                <span class="s2">&quot;Caching images for the first time. This might take about 20 minutes for COCO&quot;</span>
-            <span class="p">)</span>
-            <span class="bp">self</span><span class="o">.</span><span class="n">imgs</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">memmap</span><span class="p">(</span>
-                <span class="n">cache_file</span><span class="p">,</span>
-                <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ids</span><span class="p">),</span> <span class="n">max_h</span><span class="p">,</span> <span class="n">max_w</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span>
-                <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">,</span>
-                <span class="n">mode</span><span class="o">=</span><span class="s2">&quot;w+&quot;</span><span class="p">,</span>
-            <span class="p">)</span>
-
-            <span class="n">NUM_THREADs</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="mi">8</span><span class="p">,</span> <span class="n">os</span><span class="o">.</span><span class="n">cpu_count</span><span class="p">())</span>
-            <span class="n">loaded_images</span> <span class="o">=</span> <span class="n">ThreadPool</span><span class="p">(</span><span class="n">NUM_THREADs</span><span class="p">)</span><span class="o">.</span><span class="n">imap</span><span class="p">(</span>
-                <span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">load_resized_img</span><span class="p">(</span><span class="n">x</span><span class="p">),</span>
-                <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">annotations</span><span class="p">)),</span>
-            <span class="p">)</span>
-            <span class="n">pbar</span> <span class="o">=</span> <span class="n">tqdm</span><span class="p">(</span><span class="nb">enumerate</span><span class="p">(</span><span class="n">loaded_images</span><span class="p">),</span> <span class="n">total</span><span class="o">=</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">annotations</span><span class="p">))</span>
-            <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">out</span> <span class="ow">in</span> <span class="n">pbar</span><span class="p">:</span>
-                <span class="bp">self</span><span class="o">.</span><span class="n">imgs</span><span class="p">[</span><span class="n">k</span><span class="p">][:</span> <span class="n">out</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="p">:</span> <span class="n">out</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span
-            <span class="bp">self</span><span class="o">.</span><span class="n">imgs</span><span class="o">.</span><span class="n">flush</span><span class="p">()</span>
-            <span class="n">pbar</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>
-        <span class="k">else</span><span class="p">:</span>
-            <span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
-                <span class="s2">&quot;You are using cached imgs! Make sure your dataset is not changed!!</span><span class="se">\n</span><span class="s2">&quot;</span>
-                <span class="s2">&quot;Everytime the self.input_size is changed in your exp file, you need to delete</span><span class="se">\n</span><span class="s2">&quot;</span>
-                <span class="s2">&quot;the cached data and re-generate them.</span><span class="se">\n</span><span class="s2">&quot;</span>
-            <span class="p">)</span>
-
-        <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">&quot;Loading cached imgs...&quot;</span><span class="p">)</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">imgs</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">memmap</span><span class="p">(</span>
-            <span class="n">cache_file</span><span class="p">,</span>
-            <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ids</span><span class="p">),</span> <span class="n">max_h</span><span class="p">,</span> <span class="n">max_w</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span>
-            <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">,</span>
-            <span class="n">mode</span><span class="o">=</span><span class="s2">&quot;r+&quot;</span><span class="p">,</span>
-        <span class="p">)</span>
-
-<div class="viewcode-block" id="COCODetectionDataset.load_resized_img"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.detection_datasets.html#super_gradients.training.datasets.COCODetectionDataset.load_resized_img">[docs]</a>    <span class="k">def</span> <span class="nf">load_resized_img</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">index</span><span class="p">):</span>
-        <span class="sd">&quot;&quot;&quot;</span>
-<span class="sd">        Loads image at index, and resizes it to self.input_dim</span>
-
-<span class="sd">        :param index: index to load the image from</span>
-<span class="sd">        :return: resized_img</span>
-<span class="sd">        &quot;&quot;&quot;</span>
-        <span class="n">img</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">load_image</span><span class="p">(</span><span class="n">index</span><span class="p">)</span>
-        <span class="n">r</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">/</span> <span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><s
-        <span class="n">resized_img</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span>
-            <span class="n">img</span><span class="p">,</span>
-            <span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">r</span><span class="p">),</span> <span class="nb">int</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span c
-            <span class="n">interpolation</span><span class="o">=</span><span class="n">cv2</span><span class="o">.</span><span class="n">INTER_LINEAR</span><span class="p">,</span>
-        <span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span>
-        <span class="k">return</span> <span class="n">resized_img</span></div>
-
-<div class="viewcode-block" id="COCODetectionDataset.load_sample"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.detection_datasets.html#super_gradients.training.datasets.COCODetectionDataset.load_sample">[docs]</a>    <span class="k">def</span> <span class="nf">load_sample</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">index</span><span class="p">):</span>
-        <span class="sd">&quot;&quot;&quot;</span>
-<span class="sd">        Loads sample at self.ids[index] as dictionary that holds:</span>
-<span class="sd">            &quot;image&quot;: Image resized to self.input_dim</span>
-<span class="sd">            &quot;target&quot;: Detection ground truth, np.array shaped (num_targets, 5), format is [class,x1,y1,x2,y2] with</span>
-<span class="sd">                image coordinates.</span>
-<span class="sd">            &quot;target_seg&quot;: Segmentation map convex hull derived detection target.</span>
-<span class="sd">            &quot;info&quot;: Original shape (height,width).</span>
-<span class="sd">            &quot;id&quot;: COCO image id</span>
-
-<span class="sd">        :param index: Sample index</span>
-<span class="sd">        :return: sample as described above</span>
-<span class="sd">        &quot;&quot;&quot;</span>
-        <span class="n">id_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ids</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>
-        <span class="n">res</span><span class="p">,</span> <span class="n">res_crowd</span><span class="p">,</span> <span class="n">res_seg</span><span class="p">,</span> <span class="n">img_info</span><span class="p">,</span> <span class="n">resized_info</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">annotations</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>
-        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">imgs</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
-            <span class="n">pad_img</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">imgs</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>
-            <span class="n">img</span> <span class="o">=</span> <span class="n">pad_img</span><span class="p">[:</span> <span class="n">resized_info</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="p">:</span> <span class="n">resized_info</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="p">:]</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
-        <span class="k">else</span><span class="p">:</span>
-            <span class="n">img</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">load_resized_img</span><span class="p">(</span><span class="n">index</span><span class="p">)</span>
-
-        <span class="n">sample</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;image&quot;</span><span class="p">:</span> <span class="n">img</span><span class="p">,</span> <span class="s2">&quot;target&quot;</span><span class="p">:</span> <span class="n">res</span><span class="o">.</span><span class="n">copy</span><span class="p">(),</span> <span class="s2">&quot;target_seg&quot;</span><span class="p">:</span> <span class="n">res_seg</span><span class="p">,</span
-        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">with_crowd</span><span class="p">:</span>
-            <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;crowd_target&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">res_crowd</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
-        <span class="k">return</span> <span class="n">sample</span></div>
-
-<div class="viewcode-block" id="COCODetectionDataset.load_image"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.detection_datasets.html#super_gradients.training.datasets.COCODetectionDataset.load_image">[docs]</a>    <span class="k">def</span> <span class="nf">load_image</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">index</span><span class="p">):</span>
-        <span class="sd">&quot;&quot;&quot;</span>
-<span class="sd">        Loads image at index with its original resolution</span>
-<span class="sd">        :param index: index in self.annotations</span>
-<span class="sd">        :return: image (np.array)</span>
-<span class="sd">        &quot;&quot;&quot;</span>
-        <span class="n">file_name</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">annotations</span><span class="p">[</span><span class="n">index</span><span class="p">][</span><span class="mi">5</span><span class="p">]</span>
-
-        <span class="n">img_file</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">file_name</span><span class="p">)</span>
-
-        <span class="n">img</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">imread</span><span class="p">(</span><span class="n">img_file</span><span class="p">)</span>
-        <span class="k">assert</span> <span class="n">img</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
-
-        <span class="k">return</span> <span class="n">img</span></div>
-
-    <span class="k">def</span> <span class="fm">__del__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
-        <span class="k">del</span> <span class="bp">self</span><span class="o">.</span><span class="n">imgs</span>
-
-    <span class="k">def</span> <span class="nf">_load_anno</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">index</span><span class="p">):</span>
-        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">annotations</span><span class="p">[</span><span class="n">index</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span>
-
-    <span class="k">def</span> <span class="nf">_get_random_non_empty_target_idx</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
-        <span class="n">target</span> <span class="o">=</span> <span class="p">[]</span>
-        <span class="k">while</span> <span class="nb">len</span><span class="p">(</span><span class="n">target</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
-            <span class="n">idx</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ids</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
-            <span class="n">target</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_load_anno</span><span class="p">(</span><span class="n">idx</span><span class="p">)</span>
-        <span class="k">return</span> <span class="n">idx</span>
-
-    <span class="k">def</span> <span class="nf">_load_random_samples</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">count</span><span class="p">,</span> <span class="n">non_empty_targets_only</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
-        <span class="n">inds</span> <span class="o">=</span> <span class="p">[</span>
-            <span class="bp">self</span><span class="o">.</span><span class="n">_get_random_non_empty_target_idx</span><span class="p">()</span> <span class="k">if</span> <span class="n">non_empty_targets_only</span> <span class="k">else</span> <span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.<
-            <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">count</span><span class="p">)]</span>
-        <span class="k">return</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">load_sample</span><span class="p">(</span><span class="n">ind</span><span class="p">)</span> <span class="k">for</span> <span class="n">ind</span> <span class="ow">in</span> <span class="n">inds</span><span class="p">]</span>
-
-    <span class="k">def</span> <span class="nf">_load_additional_inputs_for_transform</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">,</span> <span class="n">transform</span><span class="p">):</span>
-        <span class="n">additional_samples_count</span> <span class="o">=</span> <span class="n">transform</span><span class="o">.</span><span class="n">additional_samples_count</span> <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">transform</span><span class="p">,</span>
-                                                                                 <span class="s2">&quot;additional_samples_count&quot;</span><span class="p">)</span> <span class="k">else</span> <span class="mi">0</span>
-        <span class="n">non_empty_targets</span> <span class="o">=</span> <span class="n">transform</span><span class="o">.</span><span class="n">non_empty_targets</span> <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">transform</span><span class="p">,</span> <span class="s2">&quot;non_empty_targets&quot;</span><span class="p">)</span> <span class="k">else</span> <span class="kc">False</span>
-        <span class="n">additional_samples</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_load_random_samples</span><span class="p">(</span><span class="n">additional_samples_count</span><span class="p">,</span> <span class="n">non_empty_targets</span><span class="p">)</span>
-        <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;additional_samples&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">additional_samples</span>
-
-<div class="viewcode-block" id="COCODetectionDataset.apply_transforms"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.detection_datasets.html#super_gradients.training.datasets.COCODetectionDataset.apply_transforms">[docs]</a>    <span class="k">def</span> <span class="nf">apply_transforms</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">:</span> <span class="nb">dict</span><span class
-        <span class="sd">&quot;&quot;&quot;</span>
-<span class="sd">        Applies self.transforms sequentially to sample</span>
-
-<span class="sd">        If a transforms has the attribute &#39;additional_samples_count&#39;, additional samples will be loaded and stored in</span>
-<span class="sd">         sample[&quot;additional_samples&quot;] prior to applying it. Combining with the attribute &quot;non_empty_targets&quot; will load</span>
-<span class="sd">         only additional samples with objects in them.</span>
-
-<span class="sd">        :param sample: Sample to apply the transforms on to (loaded with self.load_sample)</span>
-<span class="sd">        :return: Transformed sample</span>
-<span class="sd">        &quot;&quot;&quot;</span>
-        <span class="k">for</span> <span class="n">transform</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">transforms</span><span class="p">:</span>
-            <span class="bp">self</span><span class="o">.</span><span class="n">_load_additional_inputs_for_transform</span><span class="p">(</span><span class="n">sample</span><span class="p">,</span> <span class="n">transform</span><span class="p">)</span>
-            <span class="n">sample</span> <span class="o">=</span> <span class="n">transform</span><span class="p">(</span><span class="n">sample</span><span class="p">)</span>
-
-        <span class="k">return</span> <span class="n">sample</span></div></div>
-
-
-<div class="viewcode-block" id="remove_useless_info"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.detection_datasets.html#super_gradients.training.datasets.remove_useless_info">[docs]</a><span class="k">def</span> <span class="nf">remove_useless_info</span><span class="p">(</span><span class="n">coco</span><span class="p">,</span> <span class="n">use_seg_info</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
+        <span class="n">target</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">]</span> <span class="o">*=</span> <span class="n">r</span>
+        <span class="n">crowd_target</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">]</span> <span class="o">*=</span> <span class="n">r</span>
+        <span class="n">target_segmentation</span> <span class="o">*=</span> <span class="n">r</span>
+
+        <span class="n">initial_img_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">)</span>
+        <span class="n">resized_img_shape</span> <span class="o">=</span> <span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">height</span> <span class="o">*</span> <span class="n">r</span><span class="p">),</span> <span class="nb">int</span><span class="p">(</span><span class="n">width</span> <span class="o">*</span> <span class="n">r</span><span class="p">))</span>
+
+        <span class="n">file_name</span> <span class="o">=</span> <span class="n">img_metadata</span><span class="p">[</span><span class="s2">&quot;file_name&quot;</span><span class="p">]</span> <span class="k">if</span> <span class="s2">&quot;file_name&quot;</span> <span class="ow">in</span> <span class="n">img_metadata</span> <span class="k">else</span> <span class="s2">&quot;</span><span class="si">{:012}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</spa
+        <span class="n">img_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">data_dir</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">subdir</span><span class="p">,</span> <span class="n">file_name</span><span class="p">)</span>
+        <span class="n">img_id</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sample_id_to_coco_id</span><span class="p">[</span><span class="n">sample_id</span><span class="p">]</span>
+
+        <span class="n">annotation</span> <span class="o">=</span> <span class="p">{</span>
+            <span class="s2">&quot;target&quot;</span><span class="p">:</span> <span class="n">target</span><span class="p">,</span>
+            <span class="s2">&quot;crowd_target&quot;</span><span class="p">:</span> <span class="n">crowd_target</span><span class="p">,</span>
+            <span class="s2">&quot;target_segmentation&quot;</span><span class="p">:</span> <span class="n">target_segmentation</span><span class="p">,</span>
+            <span class="s2">&quot;initial_img_shape&quot;</span><span class="p">:</span> <span class="n">initial_img_shape</span><span class="p">,</span>
+            <span class="s2">&quot;resized_img_shape&quot;</span><span class="p">:</span> <span class="n">resized_img_shape</span><span class="p">,</span>
+            <span class="s2">&quot;img_path&quot;</span><span class="p">:</span> <span class="n">img_path</span><span class="p">,</span>
+            <span class="s2">&quot;id&quot;</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">img_id</span><span class="p">]),</span>
+        <span class="p">}</span>
+        <span class="k">return</span> <span class="n">annotation</span></div>
+
+
+<span class="k">def</span> <span class="nf">remove_useless_info</span><span class="p">(</span><span class="n">coco</span><span class="p">,</span> <span class="n">use_seg_info</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    Remove useless info in coco dataset. COCO object is modified inplace.</span>
 <span class="sd">    Remove useless info in coco dataset. COCO object is modified inplace.</span>
 <span class="sd">    This function is mainly used for saving memory (save about 30% mem).</span>
 <span class="sd">    This function is mainly used for saving memory (save about 30% mem).</span>
@@ -402,7 +255,7 @@
             <span class="n">img</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;flickr_url&quot;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
             <span class="n">img</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;flickr_url&quot;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
         <span class="k">if</span> <span class="s2">&quot;annotations&quot;</span> <span class="ow">in</span> <span class="n">coco</span><span class="o">.</span><span class="n">dataset</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">use_seg_info</span><span class="p">:</span>
         <span class="k">if</span> <span class="s2">&quot;annotations&quot;</span> <span class="ow">in</span> <span class="n">coco</span><span class="o">.</span><span class="n">dataset</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">use_seg_info</span><span class="p">:</span>
             <span class="k">for</span> <span class="n">anno</span> <span class="ow">in</span> <span class="n">coco</span><span class="o">.</span><span class="n">dataset</span><span class="p">[</span><span class="s2">&quot;annotations&quot;</span><span class="p">]:</span>
             <span class="k">for</span> <span class="n">anno</span> <span class="ow">in</span> <span class="n">coco</span><span class="o">.</span><span class="n">dataset</span><span class="p">[</span><span class="s2">&quot;annotations&quot;</span><span class="p">]:</span>
-                <span class="n">anno</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;segmentation&quot;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span></div>
+                <span class="n">anno</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;segmentation&quot;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
 </pre></div>
 </pre></div>
 
 
            </div>
            </div>
@@ -432,4 +285,4 @@
   </script> 
   </script> 
 
 
 </body>
 </body>
-</html>
+</html>
Discard
Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
@@ -3,10 +3,11 @@
 <head>
 <head>
   <meta charset="utf-8" />
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-  <title>super_gradients.training.datasets.detection_datasets.detection_dataset &mdash; SuperGradients 1.0 documentation</title>
+  <title>super_gradients.training.datasets.detection_datasets.detection_dataset &mdash; SuperGradients 3.0.3 documentation</title>
       <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
+      <link rel="stylesheet" href="../../../../../_static/custom.css" type="text/css" />
   <!--[if lt IE 9]>
   <!--[if lt IE 9]>
     <script src="../../../../../_static/js/html5shiv.min.js"></script>
     <script src="../../../../../_static/js/html5shiv.min.js"></script>
   <![endif]-->
   <![endif]-->
@@ -14,7 +15,9 @@
         <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
         <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
         <script src="../../../../../_static/jquery.js"></script>
         <script src="../../../../../_static/jquery.js"></script>
         <script src="../../../../../_static/underscore.js"></script>
         <script src="../../../../../_static/underscore.js"></script>
+        <script src="../../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
         <script src="../../../../../_static/doctools.js"></script>
         <script src="../../../../../_static/doctools.js"></script>
+        <script src="../../../../../_static/sphinx_highlight.js"></script>
     <script src="../../../../../_static/js/theme.js"></script>
     <script src="../../../../../_static/js/theme.js"></script>
     <link rel="index" title="Index" href="../../../../../genindex.html" />
     <link rel="index" title="Index" href="../../../../../genindex.html" />
     <link rel="search" title="Search" href="../../../../../search.html" /> 
     <link rel="search" title="Search" href="../../../../../search.html" /> 
@@ -35,30 +38,29 @@
   </form>
   </form>
 </div>
 </div>
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
-              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
+              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
 <ul>
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#supergradients">SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#quick-installation">Quick Installation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#what-s-new">What’s New</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#coming-soon">Coming soon</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#table-of-content">Table of Content</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#getting-started">Getting Started</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#advanced-features">Advanced Features</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#installation-methods">Installation Methods</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#documentation">Documentation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#contributing">Contributing</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#citation">Citation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#community">Community</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#license">License</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#deci-platform">Deci Platform</a></li>
 </ul>
 </ul>
-<p class="caption"><span class="caption-text">Technical Documentation</span></p>
+<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
 <ul>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
-</ul>
-<p class="caption"><span class="caption-text">User Guide</span></p>
-<ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html">What is SuperGradients?</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#installation">Installation</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#training-parameters">Training Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#network-architectures">Network Architectures</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
 </ul>
 </ul>
 
 
         </div>
         </div>
@@ -93,21 +95,25 @@
 <span class="kn">import</span> <span class="nn">cv2</span>
 <span class="kn">import</span> <span class="nn">cv2</span>
 <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>
 <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>
 <span class="kn">from</span> <span class="nn">pathlib</span> <span class="kn">import</span> <span class="n">Path</span>
 <span class="kn">from</span> <span class="nn">pathlib</span> <span class="kn">import</span> <span class="n">Path</span>
+<span class="kn">from</span> <span class="nn">copy</span> <span class="kn">import</span> <span class="n">deepcopy</span>
 <span class="kn">import</span> <span class="nn">hashlib</span>
 <span class="kn">import</span> <span class="nn">hashlib</span>
 
 
 <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
 <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
 <span class="kn">from</span> <span class="nn">tqdm</span> <span class="kn">import</span> <span class="n">tqdm</span>
 <span class="kn">from</span> <span class="nn">tqdm</span> <span class="kn">import</span> <span class="n">tqdm</span>
 <span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">Dataset</span>
 <span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">Dataset</span>
 
 
+<span class="kn">from</span> <span class="nn">super_gradients.common.decorators.factory_decorator</span> <span class="kn">import</span> <span class="n">resolve_param</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">get_cls_posx_in_target</span><span class="p">,</span> <span class="n">DetectionTargetsFormat</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">get_cls_posx_in_target</span><span class="p">,</span> <span class="n">DetectionTargetsFormat</span>
 <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
 <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.transforms.transforms</span> <span class="kn">import</span> <span class="n">DetectionTransform</span><span class="p">,</span> <span class="n">DetectionTargetsFormatTransform</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.transforms.transforms</span> <span class="kn">import</span> <span class="n">DetectionTransform</span><span class="p">,</span> <span class="n">DetectionTargetsFormatTransform</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.exceptions.dataset_exceptions</span> <span class="kn">import</span> <span class="n">EmptyDatasetException</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.exceptions.dataset_exceptions</span> <span class="kn">import</span> <span class="n">EmptyDatasetException</span>
+<span class="kn">from</span> <span class="nn">super_gradients.common.factories.list_factory</span> <span class="kn">import</span> <span class="n">ListFactory</span>
+<span class="kn">from</span> <span class="nn">super_gradients.common.factories.transforms_factory</span> <span class="kn">import</span> <span class="n">TransformsFactory</span>
 
 
 <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
 <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
 
 
 
 
-<div class="viewcode-block" id="DetectionDataset"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.detection_datasets.html#super_gradients.training.datasets.DetectionDataset">[docs]</a><span class="k">class</span> <span class="nc">DetectionDataset</span><span class="p">(</span><span class="n">Dataset</span><span class="p">):</span>
+<div class="viewcode-block" id="DetectionDataset"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.DetectionDataset">[docs]</a><span class="k">class</span> <span class="nc">DetectionDataset</span><span class="p">(</span><span class="n">Dataset</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;Detection dataset.</span>
     <span class="sd">&quot;&quot;&quot;Detection dataset.</span>
 
 
 <span class="sd">    This is a boilerplate class to facilitate the implementation of datasets.</span>
 <span class="sd">    This is a boilerplate class to facilitate the implementation of datasets.</span>
@@ -146,20 +152,21 @@
 <span class="sd">                                &gt; Therefore, we also have len(self) = 100</span>
 <span class="sd">                                &gt; Therefore, we also have len(self) = 100</span>
 <span class="sd">    &quot;&quot;&quot;</span>
 <span class="sd">    &quot;&quot;&quot;</span>
 
 
+    <span class="nd">@resolve_param</span><span class="p">(</span><span class="s2">&quot;transforms&quot;</span><span class="p">,</span> <span class="n">ListFactory</span><span class="p">(</span><span class="n">TransformsFactory</span><span class="p">()))</span>
     <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
     <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
-            <span class="bp">self</span><span class="p">,</span>
-            <span class="n">data_dir</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
-            <span class="n">input_dim</span><span class="p">:</span> <span class="nb">tuple</span><span class="p">,</span>
-            <span class="n">original_target_format</span><span class="p">:</span> <span class="n">DetectionTargetsFormat</span><span class="p">,</span>
-            <span class="n">max_num_samples</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
-            <span class="n">cache</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
-            <span class="n">cache_path</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
-            <span class="n">transforms</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">DetectionTransform</span><span class="p">]</span> <span class="o">=</span> <span class="p">[],</span>
-            <span class="n">all_classes_list</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
-            <span class="n">class_inclusion_list</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
-            <span class="n">ignore_empty_annotations</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
-            <span class="n">target_fields</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
-            <span class="n">output_fields</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">data_dir</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
+        <span class="n">input_dim</span><span class="p">:</span> <span class="nb">tuple</span><span class="p">,</span>
+        <span class="n">original_target_format</span><span class="p">:</span> <span class="n">DetectionTargetsFormat</span><span class="p">,</span>
+        <span class="n">max_num_samples</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">cache</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
+        <span class="n">cache_dir</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">transforms</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">DetectionTransform</span><span class="p">]</span> <span class="o">=</span> <span class="p">[],</span>
+        <span class="n">all_classes_list</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">class_inclusion_list</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">ignore_empty_annotations</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
+        <span class="n">target_fields</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">output_fields</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
     <span class="p">):</span>
     <span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;Detection dataset.</span>
         <span class="sd">&quot;&quot;&quot;Detection dataset.</span>
 
 
@@ -169,7 +176,7 @@
 <span class="sd">                                        differ based on transforms.</span>
 <span class="sd">                                        differ based on transforms.</span>
 <span class="sd">        :param max_num_samples:         If not None, set the maximum size of the dataset by only indexing the first n annotations/images.</span>
 <span class="sd">        :param max_num_samples:         If not None, set the maximum size of the dataset by only indexing the first n annotations/images.</span>
 <span class="sd">        :param cache:                   Whether to cache images or not.</span>
 <span class="sd">        :param cache:                   Whether to cache images or not.</span>
-<span class="sd">        :param cache_path:              Path to the directory where cached images will be stored in an optimized format.</span>
+<span class="sd">        :param cache_dir:              Path to the directory where cached images will be stored in an optimized format.</span>
 <span class="sd">        :param transforms:              List of transforms to apply sequentially on sample.</span>
 <span class="sd">        :param transforms:              List of transforms to apply sequentially on sample.</span>
 <span class="sd">        :param all_classes_list:        All the class names.</span>
 <span class="sd">        :param all_classes_list:        All the class names.</span>
 <span class="sd">        :param class_inclusion_list:    If not None,every class not included will be ignored.</span>
 <span class="sd">        :param class_inclusion_list:    If not None,every class not included will be ignored.</span>
@@ -208,11 +215,12 @@
         <span class="k">if</span> <span class="s2">&quot;target&quot;</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">target_fields</span><span class="p">:</span>
         <span class="k">if</span> <span class="s2">&quot;target&quot;</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">target_fields</span><span class="p">:</span>
             <span class="k">raise</span> <span class="ne">KeyError</span><span class="p">(</span><span class="s1">&#39;&quot;target&quot; is expected to be in the fields to subclass but it was not included&#39;</span><span class="p">)</span>
             <span class="k">raise</span> <span class="ne">KeyError</span><span class="p">(</span><span class="s1">&#39;&quot;target&quot; is expected to be in the fields to subclass but it was not included&#39;</span><span class="p">)</span>
 
 
+        <span class="bp">self</span><span class="o">.</span><span class="n">_required_annotation_fields</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;target&quot;</span><span class="p">,</span> <span class="s2">&quot;img_path&quot;</span><span class="p">,</span> <span class="s2">&quot;resized_img_shape&quot;</span><span class="p">}</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">annotations</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_cache_annotations</span><span class="p">()</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">annotations</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_cache_annotations</span><span class="p">()</span>
 
 
         <span class="bp">self</span><span class="o">.</span><span class="n">cache</span> <span class="o">=</span> <span class="n">cache</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">cache</span> <span class="o">=</span> <span class="n">cache</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">cache_path</span> <span class="o">=</span> <span class="n">cache_path</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">cached_imgs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_cache_images</span><span class="p">()</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cache</span> <span class="k">else</span> <span class="kc">None</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">cache_dir</span> <span class="o">=</span> <span class="n">cache_dir</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">cached_imgs_padded</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_cache_images</span><span class="p">()</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cache</span> <span class="k">else</span> <span class="kc">None</span>
 
 
         <span class="bp">self</span><span class="o">.</span><span class="n">transforms</span> <span class="o">=</span> <span class="n">transforms</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">transforms</span> <span class="o">=</span> <span class="n">transforms</span>
 
 
@@ -231,7 +239,7 @@
 <span class="sd">        Please note that the targets should be resized according to self.input_dim!</span>
 <span class="sd">        Please note that the targets should be resized according to self.input_dim!</span>
 
 
 <span class="sd">        :param sample_id:   Id of the sample to load annotations from.</span>
 <span class="sd">        :param sample_id:   Id of the sample to load annotations from.</span>
-<span class="sd">        :return:            Annotation, a dict with any field but has to include at least &quot;target&quot; and &quot;img_path&quot;.</span>
+<span class="sd">        :return:            Annotation, a dict with any field but has to include at least the fields specified in self._required_annotation_fields.</span>
 <span class="sd">        &quot;&quot;&quot;</span>
 <span class="sd">        &quot;&quot;&quot;</span>
         <span class="k">raise</span> <span class="ne">NotImplementedError</span>
         <span class="k">raise</span> <span class="ne">NotImplementedError</span>
 
 
@@ -246,8 +254,10 @@
                 <span class="k">break</span>
                 <span class="k">break</span>
 
 
             <span class="n">img_annotation</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_load_annotation</span><span class="p">(</span><span class="n">img_id</span><span class="p">)</span>
             <span class="n">img_annotation</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_load_annotation</span><span class="p">(</span><span class="n">img_id</span><span class="p">)</span>
-            <span class="k">if</span> <span class="s2">&quot;target&quot;</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">img_annotation</span> <span class="ow">or</span> <span class="s2">&quot;img_path&quot;</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">img_annotation</span><span class="p">:</span>
-                <span class="k">raise</span> <span class="ne">KeyError</span><span class="p">(</span><span class="s1">&#39;_load_annotation is expected to return at least the field &quot;target&quot; and &quot;img_path&quot;&#39;</span><span class="p">)</span>
+            <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">_required_annotation_fields</span><span class="o">.</span><span class="n">issubset</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">img_annotation</span><span class="o">.</span><span class="n">keys</span><span class="p">())):</span>
+                <span class="k">raise</span> <span class="ne">KeyError</span><span class="p">(</span>
+                    <span class="sa">f</span><span class="s2">&quot;_load_annotation is expected to return at least the fields </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_required_annotation_fields</span><span class="si">}</span><span class="s2"> &quot;</span> <span class="sa">f</span><span class="s2">&quot;but got </span><span class="si">{</span><span class="nb">set</span><span class="p">(</span><span class="n">img_annotation</span><span
+                <span class="p">)</span>
 
 
             <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">class_inclusion_list</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
             <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">class_inclusion_list</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
                 <span class="n">img_annotation</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_sub_class_annotation</span><span class="p">(</span><span class="n">img_annotation</span><span class="p">)</span>
                 <span class="n">img_annotation</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_sub_class_annotation</span><span class="p">(</span><span class="n">img_annotation</span><span class="p">)</span>
@@ -258,8 +268,9 @@
             <span class="n">annotations</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">img_annotation</span><span class="p">)</span>
             <span class="n">annotations</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">img_annotation</span><span class="p">)</span>
 
 
         <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">annotations</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
         <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">annotations</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
-            <span class="k">raise</span> <span class="n">EmptyDatasetException</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Out of </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">n_available_samples</span><span class="si">}</span><span class="s2"> images, not a single one was found with&quot;</span>
-                                        <span class="sa">f</span><span class="s2">&quot;any of these classes: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">class_inclusion_list</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
+            <span class="k">raise</span> <span class="n">EmptyDatasetException</span><span class="p">(</span>
+                <span class="sa">f</span><span class="s2">&quot;Out of </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">n_available_samples</span><span class="si">}</span><span class="s2"> images, not a single one was found with&quot;</span> <span class="sa">f</span><span class="s2">&quot;any of these classes: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">class_inclusion_list</span><span cla
+            <span class="p">)</span>
         <span class="k">return</span> <span class="n">annotations</span>
         <span class="k">return</span> <span class="n">annotations</span>
 
 
     <span class="k">def</span> <span class="nf">_sub_class_annotation</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">annotation</span><span class="p">:</span> <span class="nb">dict</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Union</span><span class="p">[</span><span class="nb">dict</span><span class="p">,</span> <span class="kc">None</span><span class="p">]:</span>
     <span class="k">def</span> <span class="nf">_sub_class_annotation</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">annotation</span><span class="p">:</span> <span class="nb">dict</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Union</span><span class="p">[</span><span class="nb">dict</span><span class="p">,</span> <span class="kc">None</span><span class="p">]:</span>
@@ -296,16 +307,17 @@
         <span class="sd">&quot;&quot;&quot;Cache the images. The cached image are stored in a file to be loaded faster mext time.</span>
         <span class="sd">&quot;&quot;&quot;Cache the images. The cached image are stored in a file to be loaded faster mext time.</span>
 <span class="sd">        :return: Cached images</span>
 <span class="sd">        :return: Cached images</span>
 <span class="sd">        &quot;&quot;&quot;</span>
 <span class="sd">        &quot;&quot;&quot;</span>
-        <span class="n">cache_path</span> <span class="o">=</span> <span class="n">Path</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cache_path</span><span class="p">)</span>
-        <span class="k">if</span> <span class="n">cache_path</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
-            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;You must specify a cache_path if you want to cache your images.&quot;</span>
-                             <span class="s2">&quot;If you did not mean to use cache, please set cache=False &quot;</span><span class="p">)</span>
-        <span class="n">cache_path</span><span class="o">.</span><span class="n">mkdir</span><span class="p">(</span><span class="n">parents</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
-
-        <span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">********************************************************************************</span><span class="se">\n</span><span class="s2">&quot;</span>
-                       <span class="s2">&quot;You are using cached images in RAM to accelerate training.</span><span class="se">\n</span><span class="s2">&quot;</span>
-                       <span class="s2">&quot;This requires large system RAM.</span><span class="se">\n</span><span class="s2">&quot;</span>
-                       <span class="s2">&quot;********************************************************************************&quot;</span><span class="p">)</span>
+        <span class="n">cache_dir</span> <span class="o">=</span> <span class="n">Path</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cache_dir</span><span class="p">)</span>
+        <span class="k">if</span> <span class="n">cache_dir</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;You must specify a cache_dir if you want to cache your images.&quot;</span> <span class="s2">&quot;If you did not mean to use cache, please set cache=False &quot;</span><span class="p">)</span>
+        <span class="n">cache_dir</span><span class="o">.</span><span class="n">mkdir</span><span class="p">(</span><span class="n">parents</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
+
+        <span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
+            <span class="s2">&quot;</span><span class="se">\n</span><span class="s2">********************************************************************************</span><span class="se">\n</span><span class="s2">&quot;</span>
+            <span class="s2">&quot;You are using cached images in RAM to accelerate training.</span><span class="se">\n</span><span class="s2">&quot;</span>
+            <span class="s2">&quot;This requires large system RAM.</span><span class="se">\n</span><span class="s2">&quot;</span>
+            <span class="s2">&quot;********************************************************************************&quot;</span>
+        <span class="p">)</span>
 
 
         <span class="n">max_h</span><span class="p">,</span> <span class="n">max_w</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
         <span class="n">max_h</span><span class="p">,</span> <span class="n">max_w</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
 
 
@@ -314,10 +326,10 @@
         <span class="k">for</span> <span class="n">annotation</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">annotations</span><span class="p">:</span>
         <span class="k">for</span> <span class="n">annotation</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">annotations</span><span class="p">:</span>
             <span class="n">values_to_hash</span> <span class="o">=</span> <span class="p">[</span><span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;resized_img_shape&quot;</span><span class="p">][</span><span class="mi">0</span><span class="p">],</span> <span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;resized_img_shape&quot;</span><span class="p">][</span><span class="mi">1</span><span class="p">],</span> <span class="n">Path</span><
             <span class="n">values_to_hash</span> <span class="o">=</span> <span class="p">[</span><span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;resized_img_shape&quot;</span><span class="p">][</span><span class="mi">0</span><span class="p">],</span> <span class="n">annotation</span><span class="p">[</span><span class="s2">&quot;resized_img_shape&quot;</span><span class="p">][</span><span class="mi">1</span><span class="p">],</span> <span class="n">Path</span><
             <span class="k">for</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">values_to_hash</span><span class="p">:</span>
             <span class="k">for</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">values_to_hash</span><span class="p">:</span>
-                <span class="nb">hash</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">value</span><span class="p">)</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="s1">&#39;utf-8&#39;</span><span class="p">))</span>
+                <span class="nb">hash</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">value</span><span class="p">)</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="s2">&quot;utf-8&quot;</span><span class="p">))</span>
         <span class="n">cache_hash</span> <span class="o">=</span> <span class="nb">hash</span><span class="o">.</span><span class="n">hexdigest</span><span class="p">()</span>
         <span class="n">cache_hash</span> <span class="o">=</span> <span class="nb">hash</span><span class="o">.</span><span class="n">hexdigest</span><span class="p">()</span>
 
 
-        <span class="n">img_resized_cache_path</span> <span class="o">=</span> <span class="n">cache_path</span> <span class="o">/</span> <span class="sa">f</span><span class="s2">&quot;img_resized_cache_</span><span class="si">{</span><span class="n">cache_hash</span><span class="si">}</span><span class="s2">.array&quot;</span>
+        <span class="n">img_resized_cache_path</span> <span class="o">=</span> <span class="n">cache_dir</span> <span class="o">/</span> <span class="sa">f</span><span class="s2">&quot;img_resized_cache_</span><span class="si">{</span><span class="n">cache_hash</span><span class="si">}</span><span class="s2">.array&quot;</span>
 
 
         <span class="k">if</span> <span class="ow">not</span> <span class="n">img_resized_cache_path</span><span class="o">.</span><span class="n">exists</span><span class="p">():</span>
         <span class="k">if</span> <span class="ow">not</span> <span class="n">img_resized_cache_path</span><span class="o">.</span><span class="n">exists</span><span class="p">():</span>
             <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">&quot;Caching images for the first time. Be aware that this will stay in the disk until you delete it yourself.&quot;</span><span class="p">)</span>
             <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">&quot;Caching images for the first time. Be aware that this will stay in the disk until you delete it yourself.&quot;</span><span class="p">)</span>
@@ -325,8 +337,7 @@
             <span class="n">loaded_images</span> <span class="o">=</span> <span class="n">ThreadPool</span><span class="p">(</span><span class="n">NUM_THREADs</span><span class="p">)</span><span class="o">.</span><span class="n">imap</span><span class="p">(</span><span class="n">func</span><span class="o">=</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_load_resized_img</span><span class="
             <span class="n">loaded_images</span> <span class="o">=</span> <span class="n">ThreadPool</span><span class="p">(</span><span class="n">NUM_THREADs</span><span class="p">)</span><span class="o">.</span><span class="n">imap</span><span class="p">(</span><span class="n">func</span><span class="o">=</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_load_resized_img</span><span class="
 
 
             <span class="c1"># Initialize placeholder for images</span>
             <span class="c1"># Initialize placeholder for images</span>
-            <span class="n">cached_imgs</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">memmap</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">img_resized_cache_path</span><span class="p">),</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="p">),</span> <span class="n">max_h<
-                                    <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s2">&quot;w+&quot;</span><span class="p">)</span>
+            <span class="n">cached_imgs</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">memmap</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">img_resized_cache_path</span><span class="p">),</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="p">),</span> <span class="n">max_h<
 
 
             <span class="c1"># Store images in the placeholder</span>
             <span class="c1"># Store images in the placeholder</span>
             <span class="n">loaded_images_pbar</span> <span class="o">=</span> <span class="n">tqdm</span><span class="p">(</span><span class="nb">enumerate</span><span class="p">(</span><span class="n">loaded_images</span><span class="p">),</span> <span class="n">total</span><span class="o">=</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="p">))</span>
             <span class="n">loaded_images_pbar</span> <span class="o">=</span> <span class="n">tqdm</span><span class="p">(</span><span class="nb">enumerate</span><span class="p">(</span><span class="n">loaded_images</span><span class="p">),</span> <span class="n">total</span><span class="o">=</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="p">))</span>
@@ -338,8 +349,7 @@
             <span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s2">&quot;You are using cached imgs!&quot;</span><span class="p">)</span>
             <span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s2">&quot;You are using cached imgs!&quot;</span><span class="p">)</span>
 
 
         <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">&quot;Loading cached imgs...&quot;</span><span class="p">)</span>
         <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">&quot;Loading cached imgs...&quot;</span><span class="p">)</span>
-        <span class="n">cached_imgs</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">memmap</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">img_resized_cache_path</span><span class="p">),</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="p">),</span> <span class="n">max_h</spa
-                                <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s2">&quot;r+&quot;</span><span class="p">)</span>
+        <span class="n">cached_imgs</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">memmap</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">img_resized_cache_path</span><span class="p">),</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="p">),</span> <span class="n">max_h</spa
         <span class="k">return</span> <span class="n">cached_imgs</span>
         <span class="k">return</span> <span class="n">cached_imgs</span>
 
 
     <span class="k">def</span> <span class="nf">_load_resized_img</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">index</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">:</span>
     <span class="k">def</span> <span class="nf">_load_resized_img</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">index</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">:</span>
@@ -366,14 +376,13 @@
         <span class="n">img</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">imread</span><span class="p">(</span><span class="n">img_file</span><span class="p">)</span>
         <span class="n">img</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">imread</span><span class="p">(</span><span class="n">img_file</span><span class="p">)</span>
 
 
         <span class="k">if</span> <span class="n">img</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
         <span class="k">if</span> <span class="n">img</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
-            <span class="k">raise</span> <span class="ne">FileNotFoundError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">img_file</span><span class="si">}</span><span class="s2"> was no found. Please make sure that the dataset was&quot;</span>
-                                    <span class="sa">f</span><span class="s2">&quot;downloaded and that the path is correct&quot;</span><span class="p">)</span>
+            <span class="k">raise</span> <span class="ne">FileNotFoundError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">img_file</span><span class="si">}</span><span class="s2"> was no found. Please make sure that the dataset was&quot;</span> <span class="sa">f</span><span class="s2">&quot;downloaded and that the path is correct&quot;</span><span class="p">)</span>
         <span class="k">return</span> <span class="n">img</span>
         <span class="k">return</span> <span class="n">img</span>
 
 
     <span class="k">def</span> <span class="fm">__del__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
     <span class="k">def</span> <span class="fm">__del__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;Clear the cached images&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;Clear the cached images&quot;&quot;&quot;</span>
-        <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="s2">&quot;cached_imgs&quot;</span><span class="p">):</span>
-            <span class="k">del</span> <span class="bp">self</span><span class="o">.</span><span class="n">cached_imgs</span>
+        <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="s2">&quot;cached_imgs_padded&quot;</span><span class="p">):</span>
+            <span class="k">del</span> <span class="bp">self</span><span class="o">.</span><span class="n">cached_imgs_padded</span>
 
 
     <span class="k">def</span> <span class="fm">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
     <span class="k">def</span> <span class="fm">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;Get the length of the dataset.&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;Get the length of the dataset.&quot;&quot;&quot;</span>
@@ -385,34 +394,37 @@
         <span class="n">sample</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">apply_transforms</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">get_sample</span><span class="p">(</span><span class="n">index</span><span class="p">))</span>
         <span class="n">sample</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">apply_transforms</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">get_sample</span><span class="p">(</span><span class="n">index</span><span class="p">))</span>
         <span class="k">for</span> <span class="n">field</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_fields</span><span class="p">:</span>
         <span class="k">for</span> <span class="n">field</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_fields</span><span class="p">:</span>
             <span class="k">if</span> <span class="n">field</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">sample</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
             <span class="k">if</span> <span class="n">field</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">sample</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
-                <span class="k">raise</span> <span class="ne">KeyError</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;The field </span><span class="si">{</span><span class="n">field</span><span class="si">}</span><span class="s1"> must be present in the sample but was not found.&#39;</span>
-                               <span class="s1">&#39;Please check the output fields of your transforms.&#39;</span><span class="p">)</span>
+                <span class="k">raise</span> <span class="ne">KeyError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;The field </span><span class="si">{</span><span class="n">field</span><span class="si">}</span><span class="s2"> must be present in the sample but was not found.&quot;</span> <span class="s2">&quot;Please check the output fields of your transforms.&quot;</span><span class="p">)</span>
         <span class="k">return</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">sample</span><span class="p">[</span><span class="n">field</span><span class="p">]</span> <span class="k">for</span> <span class="n">field</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_fields</span><span class="p">)</span>
         <span class="k">return</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">sample</span><span class="p">[</span><span class="n">field</span><span class="p">]</span> <span class="k">for</span> <span class="n">field</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_fields</span><span class="p">)</span>
 
 
-<div class="viewcode-block" id="DetectionDataset.get_random_item"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.detection_datasets.html#super_gradients.training.datasets.DetectionDataset.get_random_item">[docs]</a>    <span class="k">def</span> <span class="nf">get_random_item</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
+<div class="viewcode-block" id="DetectionDataset.get_random_item"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.DetectionDataset.get_random_item">[docs]</a>    <span class="k">def</span> <span class="nf">get_random_item</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
         <span class="k">return</span> <span class="bp">self</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_random_index</span><span class="p">()]</span></div>
         <span class="k">return</span> <span class="bp">self</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">_random_index</span><span class="p">()]</span></div>
 
 
-<div class="viewcode-block" id="DetectionDataset.get_sample"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.detection_datasets.html#super_gradients.training.datasets.DetectionDataset.get_sample">[docs]</a>    <span class="k">def</span> <span class="nf">get_sample</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">index</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o
+<div class="viewcode-block" id="DetectionDataset.get_sample"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.DetectionDataset.get_sample">[docs]</a>    <span class="k">def</span> <span class="nf">get_sample</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">index</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="
         <span class="sd">&quot;&quot;&quot;Get raw sample, before any transform (beside subclassing).</span>
         <span class="sd">&quot;&quot;&quot;Get raw sample, before any transform (beside subclassing).</span>
 <span class="sd">        :param index:   Image index</span>
 <span class="sd">        :param index:   Image index</span>
 <span class="sd">        :return:        Sample, i.e. a dictionary including at least &quot;image&quot; and &quot;target&quot;</span>
 <span class="sd">        :return:        Sample, i.e. a dictionary including at least &quot;image&quot; and &quot;target&quot;</span>
 <span class="sd">        &quot;&quot;&quot;</span>
 <span class="sd">        &quot;&quot;&quot;</span>
         <span class="n">img</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_resized_image</span><span class="p">(</span><span class="n">index</span><span class="p">)</span>
         <span class="n">img</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_resized_image</span><span class="p">(</span><span class="n">index</span><span class="p">)</span>
-        <span class="n">annotation</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">annotations</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>
+        <span class="n">annotation</span> <span class="o">=</span> <span class="n">deepcopy</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">annotations</span><span class="p">[</span><span class="n">index</span><span class="p">])</span>
         <span class="k">return</span> <span class="p">{</span><span class="s2">&quot;image&quot;</span><span class="p">:</span> <span class="n">img</span><span class="p">,</span> <span class="o">**</span><span class="n">annotation</span><span class="p">}</span></div>
         <span class="k">return</span> <span class="p">{</span><span class="s2">&quot;image&quot;</span><span class="p">:</span> <span class="n">img</span><span class="p">,</span> <span class="o">**</span><span class="n">annotation</span><span class="p">}</span></div>
 
 
-<div class="viewcode-block" id="DetectionDataset.get_resized_image"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.detection_datasets.html#super_gradients.training.datasets.DetectionDataset.get_resized_image">[docs]</a>    <span class="k">def</span> <span class="nf">get_resized_image</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">index</span><span class="p">:</span> <span class="nb">int</span><span class="p">)<
+<div class="viewcode-block" id="DetectionDataset.get_resized_image"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.DetectionDataset.get_resized_image">[docs]</a>    <span class="k">def</span> <span class="nf">get_resized_image</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">index</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
-<span class="sd">        Get the resized image at a specific sample_id, either from cache or by loading from disk, based on self.cached_imgs</span>
+<span class="sd">        Get the resized image (i.e. either width or height reaches its input_dim) at a specific sample_id,</span>
+<span class="sd">        either from cache or by loading from disk, based on self.cached_imgs_padded</span>
 <span class="sd">        :param index:  Image index</span>
 <span class="sd">        :param index:  Image index</span>
 <span class="sd">        :return:       Resized image</span>
 <span class="sd">        :return:       Resized image</span>
 <span class="sd">        &quot;&quot;&quot;</span>
 <span class="sd">        &quot;&quot;&quot;</span>
         <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cache</span><span class="p">:</span>
         <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cache</span><span class="p">:</span>
-            <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">cached_imgs</span><span class="p">[</span><span class="n">index</span><span class="p">]</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
+            <span class="n">padded_image</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">cached_imgs_padded</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>
+            <span class="n">resized_height</span><span class="p">,</span> <span class="n">resized_width</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">annotations</span><span class="p">[</span><span class="n">index</span><span class="p">][</span><span class="s2">&quot;resized_img_shape&quot;</span><span class="p">]</span>
+            <span class="n">resized_image</span> <span class="o">=</span> <span class="n">padded_image</span><span class="p">[:</span><span class="n">resized_height</span><span class="p">,</span> <span class="p">:</span><span class="n">resized_width</span><span class="p">,</span> <span class="p">:]</span>
+            <span class="k">return</span> <span class="n">resized_image</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
         <span class="k">else</span><span class="p">:</span>
         <span class="k">else</span><span class="p">:</span>
             <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_load_resized_img</span><span class="p">(</span><span class="n">index</span><span class="p">)</span></div>
             <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_load_resized_img</span><span class="p">(</span><span class="n">index</span><span class="p">)</span></div>
 
 
-<div class="viewcode-block" id="DetectionDataset.apply_transforms"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.detection_datasets.html#super_gradients.training.datasets.DetectionDataset.apply_transforms">[docs]</a>    <span class="k">def</span> <span class="nf">apply_transforms</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</s
+<div class="viewcode-block" id="DetectionDataset.apply_transforms"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.DetectionDataset.apply_transforms">[docs]</a>    <span class="k">def</span> <span class="nf">apply_transforms</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</sp
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">        Applies self.transforms sequentially to sample</span>
 <span class="sd">        Applies self.transforms sequentially to sample</span>
 
 
@@ -429,17 +441,14 @@
             <span class="n">sample</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;additional_samples&quot;</span><span class="p">)</span>  <span class="c1"># additional_samples is not useful after the transform</span>
             <span class="n">sample</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;additional_samples&quot;</span><span class="p">)</span>  <span class="c1"># additional_samples is not useful after the transform</span>
         <span class="k">return</span> <span class="n">sample</span></div>
         <span class="k">return</span> <span class="n">sample</span></div>
 
 
-    <span class="k">def</span> <span class="nf">_add_additional_inputs_for_transform</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Union</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">An
-                                             <span class="n">transform</span><span class="p">:</span> <span class="n">DetectionTransform</span><span class="p">):</span>
+    <span class="k">def</span> <span class="nf">_add_additional_inputs_for_transform</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Union</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">An
         <span class="sd">&quot;&quot;&quot;Add additional inputs required by a transform to the sample&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;Add additional inputs required by a transform to the sample&quot;&quot;&quot;</span>
-        <span class="n">additional_samples_count</span> <span class="o">=</span> <span class="n">transform</span><span class="o">.</span><span class="n">additional_samples_count</span> <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">transform</span><span class="p">,</span>
-                                                                                 <span class="s2">&quot;additional_samples_count&quot;</span><span class="p">)</span> <span class="k">else</span> <span class="mi">0</span>
+        <span class="n">additional_samples_count</span> <span class="o">=</span> <span class="n">transform</span><span class="o">.</span><span class="n">additional_samples_count</span> <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">transform</span><span class="p">,</span> <span class="s2">&quot;additional_samples_count&quot;</span><span class="p">)</span> <span class="k">else</span> <span class="mi">0</span>
         <span class="n">non_empty_annotations</span> <span class="o">=</span> <span class="n">transform</span><span class="o">.</span><span class="n">non_empty_annotations</span> <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">transform</span><span class="p">,</span> <span class="s2">&quot;non_empty_annotations&quot;</span><span class="p">)</span> <span class="k">else</span> <span class="kc">False</span>
         <span class="n">non_empty_annotations</span> <span class="o">=</span> <span class="n">transform</span><span class="o">.</span><span class="n">non_empty_annotations</span> <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">transform</span><span class="p">,</span> <span class="s2">&quot;non_empty_annotations&quot;</span><span class="p">)</span> <span class="k">else</span> <span class="kc">False</span>
         <span class="n">additional_samples</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_random_samples</span><span class="p">(</span><span class="n">additional_samples_count</span><span class="p">,</span> <span class="n">non_empty_annotations</span><span class="p">)</span>
         <span class="n">additional_samples</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_random_samples</span><span class="p">(</span><span class="n">additional_samples_count</span><span class="p">,</span> <span class="n">non_empty_annotations</span><span class="p">)</span>
         <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;additional_samples&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">additional_samples</span>
         <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;additional_samples&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">additional_samples</span>
 
 
-<div class="viewcode-block" id="DetectionDataset.get_random_samples"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.detection_datasets.html#super_gradients.training.datasets.DetectionDataset.get_random_samples">[docs]</a>    <span class="k">def</span> <span class="nf">get_random_samples</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">count</span><span class="p">:</span> <span class="nb">int</span><span class="p"
-                           <span class="n">non_empty_annotations_only</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Union</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span 
+<div class="viewcode-block" id="DetectionDataset.get_random_samples"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.DetectionDataset.get_random_samples">[docs]</a>    <span class="k">def</span> <span class="nf">get_random_samples</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">count</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">no
         <span class="sd">&quot;&quot;&quot;Load random samples.</span>
         <span class="sd">&quot;&quot;&quot;Load random samples.</span>
 
 
 <span class="sd">        :param count: The number of samples wanted</span>
 <span class="sd">        :param count: The number of samples wanted</span>
@@ -448,7 +457,7 @@
 <span class="sd">        &quot;&quot;&quot;</span>
 <span class="sd">        &quot;&quot;&quot;</span>
         <span class="k">return</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">get_random_sample</span><span class="p">(</span><span class="n">non_empty_annotations_only</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">count</span><span class="p">)]</span></div>
         <span class="k">return</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">get_random_sample</span><span class="p">(</span><span class="n">non_empty_annotations_only</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">count</span><span class="p">)]</span></div>
 
 
-<div class="viewcode-block" id="DetectionDataset.get_random_sample"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.detection_datasets.html#super_gradients.training.datasets.DetectionDataset.get_random_sample">[docs]</a>    <span class="k">def</span> <span class="nf">get_random_sample</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">non_empty_annotations_only</span><span class="p">:</span> <span class="nb">bool</s
+<div class="viewcode-block" id="DetectionDataset.get_random_sample"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.DetectionDataset.get_random_sample">[docs]</a>    <span class="k">def</span> <span class="nf">get_random_sample</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">non_empty_annotations_only</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span
         <span class="k">if</span> <span class="n">non_empty_annotations_only</span><span class="p">:</span>
         <span class="k">if</span> <span class="n">non_empty_annotations_only</span><span class="p">:</span>
             <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_sample</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_random_non_empty_annotation_available_indexes</span><span class="p">())</span>
             <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_sample</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_random_non_empty_annotation_available_indexes</span><span class="p">())</span>
         <span class="k">else</span><span class="p">:</span>
         <span class="k">else</span><span class="p">:</span>
@@ -475,23 +484,22 @@
                 <span class="n">target_format</span> <span class="o">=</span> <span class="n">transform</span><span class="o">.</span><span class="n">output_format</span>
                 <span class="n">target_format</span> <span class="o">=</span> <span class="n">transform</span><span class="o">.</span><span class="n">output_format</span>
         <span class="k">return</span> <span class="n">target_format</span>
         <span class="k">return</span> <span class="n">target_format</span>
 
 
-<div class="viewcode-block" id="DetectionDataset.plot"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.detection_datasets.html#super_gradients.training.datasets.DetectionDataset.plot">[docs]</a>    <span class="k">def</span> <span class="nf">plot</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">max_samples_per_plot</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi"
+<div class="viewcode-block" id="DetectionDataset.plot"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.DetectionDataset.plot">[docs]</a>    <span class="k">def</span> <span class="nf">plot</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">max_samples_per_plot</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">16</span><span class="p">,<
         <span class="sd">&quot;&quot;&quot;Combine samples of images with bbox into plots and display the result.</span>
         <span class="sd">&quot;&quot;&quot;Combine samples of images with bbox into plots and display the result.</span>
 
 
-<span class="sd">            :param max_samples_per_plot:    Maximum number of images to be displayed per plot</span>
-<span class="sd">            :param n_plots:                 Number of plots to display (each plot being a combination of img with bbox)</span>
-<span class="sd">            :param plot_transformed_data:   If True, the plot will be over samples after applying transforms (i.e. on __getitem__).</span>
-<span class="sd">                                            If False, the plot will be over the raw samples (i.e. on get_sample)</span>
-<span class="sd">            :return:</span>
+<span class="sd">        :param max_samples_per_plot:    Maximum number of images to be displayed per plot</span>
+<span class="sd">        :param n_plots:                 Number of plots to display (each plot being a combination of img with bbox)</span>
+<span class="sd">        :param plot_transformed_data:   If True, the plot will be over samples after applying transforms (i.e. on __getitem__).</span>
+<span class="sd">                                        If False, the plot will be over the raw samples (i.e. on get_sample)</span>
+<span class="sd">        :return:</span>
 <span class="sd">        &quot;&quot;&quot;</span>
 <span class="sd">        &quot;&quot;&quot;</span>
         <span class="n">plot_counter</span> <span class="o">=</span> <span class="mi">0</span>
         <span class="n">plot_counter</span> <span class="o">=</span> <span class="mi">0</span>
         <span class="n">input_format</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_target_format</span> <span class="k">if</span> <span class="n">plot_transformed_data</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">original_target_format</span>
         <span class="n">input_format</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_target_format</span> <span class="k">if</span> <span class="n">plot_transformed_data</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">original_target_format</span>
-        <span class="n">target_format_transform</span> <span class="o">=</span> <span class="n">DetectionTargetsFormatTransform</span><span class="p">(</span><span class="n">input_format</span><span class="o">=</span><span class="n">input_format</span><span class="p">,</span>
-                                                                  <span class="n">output_format</span><span class="o">=</span><span class="n">DetectionTargetsFormat</span><span class="o">.</span><span class="n">XYXY_LABEL</span><span class="p">)</span>
+        <span class="n">target_format_transform</span> <span class="o">=</span> <span class="n">DetectionTargetsFormatTransform</span><span class="p">(</span><span class="n">input_format</span><span class="o">=</span><span class="n">input_format</span><span class="p">,</span> <span class="n">output_format</span><span class="o">=</span><span class="n">DetectionTargetsFormat</span><span class="o">.</span><span class="n">XYXY_LABEL</span><span class="p">)</span>
 
 
         <span class="k">for</span> <span class="n">plot_i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_plots</span><span class="p">):</span>
         <span class="k">for</span> <span class="n">plot_i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_plots</span><span class="p">):</span>
             <span class="n">fig</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span>
             <span class="n">fig</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span>
-            <span class="n">n_subplot</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">max_samples_per_plot</span> <span class="o">**</span> <span class="mf">0.5</span><span class="p">))</span>
+            <span class="n">n_subplot</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">max_samples_per_plot</span><span class="o">**</span><span class="mf">0.5</span><span class="p">))</span>
             <span class="k">for</span> <span class="n">img_i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">max_samples_per_plot</span><span class="p">):</span>
             <span class="k">for</span> <span class="n">img_i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">max_samples_per_plot</span><span class="p">):</span>
                 <span class="n">index</span> <span class="o">=</span> <span class="n">img_i</span> <span class="o">+</span> <span class="n">plot_i</span> <span class="o">*</span> <span class="mi">16</span>
                 <span class="n">index</span> <span class="o">=</span> <span class="n">img_i</span> <span class="o">+</span> <span class="n">plot_i</span> <span class="o">*</span> <span class="mi">16</span>
 
 
@@ -509,9 +517,9 @@
 
 
                 <span class="c1"># shape = [n_box x 4] (We remove padded boxes, which corresponds to boxes with only 0)</span>
                 <span class="c1"># shape = [n_box x 4] (We remove padded boxes, which corresponds to boxes with only 0)</span>
                 <span class="n">boxes</span> <span class="o">=</span> <span class="n">boxes</span><span class="p">[(</span><span class="n">boxes</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">any</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)]</span>
                 <span class="n">boxes</span> <span class="o">=</span> <span class="n">boxes</span><span class="p">[(</span><span class="n">boxes</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">any</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)]</span>
-                <span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="n">n_subplot</span><span class="p">,</span> <span class="n">n_subplot</span><span class="p">,</span> <span class="n">img_i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">image</span><span class="p">)</span>
-                <span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">boxes</span><span class="p">[:,</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">]]</span><span class="o">.</span><span class="n">T</span><span
-                <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s1">&#39;off&#39;</span><span class="p">)</span>
+                <span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="n">n_subplot</span><span class="p">,</span> <span class="n">n_subplot</span><span class="p">,</span> <span class="n">img_i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">image</span><span class="p">[:,</span> <span class="p">:,</span> <spa
+                <span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">boxes</span><span class="p">[:,</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">]]</span><span class="o">.</span><span class="n">T</span><span
+                <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">&quot;off&quot;</span><span class="p">)</span>
             <span class="n">fig</span><span class="o">.</span><span class="n">tight_layout</span><span class="p">()</span>
             <span class="n">fig</span><span class="o">.</span><span class="n">tight_layout</span><span class="p">()</span>
             <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span>
             <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span>
             <span class="n">plt</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>
             <span class="n">plt</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>
@@ -548,4 +556,4 @@
   </script> 
   </script> 
 
 
 </body>
 </body>
-</html>
+</html>
Discard
Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
@@ -3,10 +3,11 @@
 <head>
 <head>
   <meta charset="utf-8" />
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-  <title>super_gradients.training.datasets.detection_datasets.pascal_voc_detection &mdash; SuperGradients 1.0 documentation</title>
+  <title>super_gradients.training.datasets.detection_datasets.pascal_voc_detection &mdash; SuperGradients 3.0.3 documentation</title>
       <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
+      <link rel="stylesheet" href="../../../../../_static/custom.css" type="text/css" />
   <!--[if lt IE 9]>
   <!--[if lt IE 9]>
     <script src="../../../../../_static/js/html5shiv.min.js"></script>
     <script src="../../../../../_static/js/html5shiv.min.js"></script>
   <![endif]-->
   <![endif]-->
@@ -14,7 +15,9 @@
         <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
         <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
         <script src="../../../../../_static/jquery.js"></script>
         <script src="../../../../../_static/jquery.js"></script>
         <script src="../../../../../_static/underscore.js"></script>
         <script src="../../../../../_static/underscore.js"></script>
+        <script src="../../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
         <script src="../../../../../_static/doctools.js"></script>
         <script src="../../../../../_static/doctools.js"></script>
+        <script src="../../../../../_static/sphinx_highlight.js"></script>
     <script src="../../../../../_static/js/theme.js"></script>
     <script src="../../../../../_static/js/theme.js"></script>
     <link rel="index" title="Index" href="../../../../../genindex.html" />
     <link rel="index" title="Index" href="../../../../../genindex.html" />
     <link rel="search" title="Search" href="../../../../../search.html" /> 
     <link rel="search" title="Search" href="../../../../../search.html" /> 
@@ -35,30 +38,29 @@
   </form>
   </form>
 </div>
 </div>
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
-              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
+              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
 <ul>
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#supergradients">SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#quick-installation">Quick Installation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#what-s-new">What’s New</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#coming-soon">Coming soon</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#table-of-content">Table of Content</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#getting-started">Getting Started</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#advanced-features">Advanced Features</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#installation-methods">Installation Methods</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#documentation">Documentation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#contributing">Contributing</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#citation">Citation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#community">Community</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#license">License</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#deci-platform">Deci Platform</a></li>
 </ul>
 </ul>
-<p class="caption"><span class="caption-text">Technical Documentation</span></p>
+<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
 <ul>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
-</ul>
-<p class="caption"><span class="caption-text">User Guide</span></p>
-<ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html">What is SuperGradients?</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#installation">Installation</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#training-parameters">Training Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#network-architectures">Network Architectures</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
 </ul>
 </ul>
 
 
         </div>
         </div>
@@ -89,11 +91,15 @@
 <span></span><span class="kn">import</span> <span class="nn">os</span>
 <span></span><span class="kn">import</span> <span class="nn">os</span>
 <span class="kn">import</span> <span class="nn">glob</span>
 <span class="kn">import</span> <span class="nn">glob</span>
 <span class="kn">from</span> <span class="nn">pathlib</span> <span class="kn">import</span> <span class="n">Path</span>
 <span class="kn">from</span> <span class="nn">pathlib</span> <span class="kn">import</span> <span class="n">Path</span>
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Optional</span>
 <span class="kn">from</span> <span class="nn">xml.etree</span> <span class="kn">import</span> <span class="n">ElementTree</span>
 <span class="kn">from</span> <span class="nn">xml.etree</span> <span class="kn">import</span> <span class="n">ElementTree</span>
+
+<span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">ConcatDataset</span>
 <span class="kn">from</span> <span class="nn">tqdm</span> <span class="kn">import</span> <span class="n">tqdm</span>
 <span class="kn">from</span> <span class="nn">tqdm</span> <span class="kn">import</span> <span class="n">tqdm</span>
 
 
 <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
 <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
 
 
+<span class="kn">from</span> <span class="nn">super_gradients.training.transforms.transforms</span> <span class="kn">import</span> <span class="n">DetectionTransform</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.utils.utils</span> <span class="kn">import</span> <span class="n">download_and_untar_from_url</span><span class="p">,</span> <span class="n">get_image_size_from_path</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.utils.utils</span> <span class="kn">import</span> <span class="n">download_and_untar_from_url</span><span class="p">,</span> <span class="n">get_image_size_from_path</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.detection_datasets.detection_dataset</span> <span class="kn">import</span> <span class="n">DetectionDataset</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.detection_datasets.detection_dataset</span> <span class="kn">import</span> <span class="n">DetectionDataset</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">DetectionTargetsFormat</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">DetectionTargetsFormat</span>
@@ -103,19 +109,25 @@
 <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
 <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
 
 
 
 
-<div class="viewcode-block" id="PascalVOCDetectionDataset"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.detection_datasets.html#super_gradients.training.datasets.PascalVOCDetectionDataset">[docs]</a><span class="k">class</span> <span class="nc">PascalVOCDetectionDataset</span><span class="p">(</span><span class="n">DetectionDataset</span><span class="p">):</span>
+<div class="viewcode-block" id="PascalVOCDetectionDataset"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.PascalVOCDetectionDataset">[docs]</a><span class="k">class</span> <span class="nc">PascalVOCDetectionDataset</span><span class="p">(</span><span class="n">DetectionDataset</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;Dataset for Pascal VOC object detection&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;Dataset for Pascal VOC object detection&quot;&quot;&quot;</span>
 
 
-    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">images_sub_directory</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">images_sub_directory</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">download</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span>
         <span class="sd">&quot;&quot;&quot;Dataset for Pascal VOC object detection</span>
         <span class="sd">&quot;&quot;&quot;Dataset for Pascal VOC object detection</span>
 
 
 <span class="sd">        :param images_sub_directory:    Sub directory of data_dir that includes images.</span>
 <span class="sd">        :param images_sub_directory:    Sub directory of data_dir that includes images.</span>
 <span class="sd">        &quot;&quot;&quot;</span>
 <span class="sd">        &quot;&quot;&quot;</span>
+
         <span class="bp">self</span><span class="o">.</span><span class="n">images_sub_directory</span> <span class="o">=</span> <span class="n">images_sub_directory</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">images_sub_directory</span> <span class="o">=</span> <span class="n">images_sub_directory</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">img_and_target_path_list</span> <span class="o">=</span> <span class="kc">None</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">img_and_target_path_list</span> <span class="o">=</span> <span class="kc">None</span>
-
-        <span class="n">kwargs</span><span class="p">[</span><span class="s1">&#39;all_classes_list&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">PASCAL_VOC_2012_CLASSES_LIST</span>
-        <span class="n">kwargs</span><span class="p">[</span><span class="s1">&#39;original_target_format&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">DetectionTargetsFormat</span><span class="o">.</span><span class="n">XYXY_LABEL</span>
+        <span class="n">data_dir</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;data_dir&quot;</span><span class="p">)</span>
+        <span class="k">if</span> <span class="n">data_dir</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Must pass data_dir != None through **kwargs&quot;</span><span class="p">)</span>
+        <span class="k">if</span> <span class="n">download</span><span class="p">:</span>
+            <span class="n">PascalVOCDetectionDataset</span><span class="o">.</span><span class="n">download</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span>
+
+        <span class="n">kwargs</span><span class="p">[</span><span class="s2">&quot;original_target_format&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">DetectionTargetsFormat</span><span class="o">.</span><span class="n">XYXY_LABEL</span>
+        <span class="n">kwargs</span><span class="p">[</span><span class="s2">&quot;all_classes_list&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">PASCAL_VOC_2012_CLASSES_LIST</span>
         <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
         <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
 
 
     <span class="k">def</span> <span class="nf">_setup_data_source</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
     <span class="k">def</span> <span class="nf">_setup_data_source</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
@@ -125,9 +137,11 @@
 <span class="sd">        &quot;&quot;&quot;</span>
 <span class="sd">        &quot;&quot;&quot;</span>
         <span class="n">img_files_folder</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">data_dir</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">images_sub_directory</span>
         <span class="n">img_files_folder</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">data_dir</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">images_sub_directory</span>
         <span class="k">if</span> <span class="ow">not</span> <span class="n">Path</span><span class="p">(</span><span class="n">img_files_folder</span><span class="p">)</span><span class="o">.</span><span class="n">exists</span><span class="p">():</span>
         <span class="k">if</span> <span class="ow">not</span> <span class="n">Path</span><span class="p">(</span><span class="n">img_files_folder</span><span class="p">)</span><span class="o">.</span><span class="n">exists</span><span class="p">():</span>
-            <span class="k">raise</span> <span class="ne">FileNotFoundError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">data_dir</span><span class="si">}</span><span class="s2"> does not include </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">images_sub_directory</span><span class="si">}</span><span class="s2">.
-                                    <span class="sa">f</span><span class="s2">&quot;Please make sure that f</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">data_dir</span><span class="si">}</span><span class="s2"> refers to PascalVOC dataset and that &quot;</span>
-                                    <span class="s2">&quot;it was downloaded using PascalVOCDetectionDataSetV2.download()&quot;</span><span class="p">)</span>
+            <span class="k">raise</span> <span class="ne">FileNotFoundError</span><span class="p">(</span>
+                <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">data_dir</span><span class="si">}</span><span class="s2"> does not include </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">images_sub_directory</span><span class="si">}</span><span class="s2">. &quot;</span>
+                <span class="sa">f</span><span class="s2">&quot;Please make sure that f</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">data_dir</span><span class="si">}</span><span class="s2"> refers to PascalVOC dataset and that &quot;</span>
+                <span class="s2">&quot;it was downloaded using PascalVOCDetectionDataSetV2.download()&quot;</span>
+            <span class="p">)</span>
 
 
         <span class="n">img_files</span> <span class="o">=</span> <span class="n">glob</span><span class="o">.</span><span class="n">glob</span><span class="p">(</span><span class="n">img_files_folder</span> <span class="o">+</span> <span class="s2">&quot;*.jpg&quot;</span><span class="p">)</span>
         <span class="n">img_files</span> <span class="o">=</span> <span class="n">glob</span><span class="o">.</span><span class="n">glob</span><span class="p">(</span><span class="n">img_files_folder</span> <span class="o">+</span> <span class="s2">&quot;*.jpg&quot;</span><span class="p">)</span>
         <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">img_files</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
         <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">img_files</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
@@ -135,15 +149,13 @@
 
 
         <span class="n">target_files</span> <span class="o">=</span> <span class="p">[</span><span class="n">img_file</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">&quot;images&quot;</span><span class="p">,</span> <span class="s2">&quot;labels&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">&quot;.jpg&quot;</span><span class="p">,</span> <span class="s2">&quo
         <span class="n">target_files</span> <span class="o">=</span> <span class="p">[</span><span class="n">img_file</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">&quot;images&quot;</span><span class="p">,</span> <span class="s2">&quot;labels&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">&quot;.jpg&quot;</span><span class="p">,</span> <span class="s2">&quo
 
 
-        <span class="n">img_and_target_path_list</span> <span class="o">=</span> <span class="p">[(</span><span class="n">img_file</span><span class="p">,</span> <span class="n">target_file</span><span class="p">)</span>
-                                    <span class="k">for</span> <span class="n">img_file</span><span class="p">,</span> <span class="n">target_file</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">img_files</span><span class="p">,</span> <span class="n">target_files</span><span class="p">)</span>
-                                    <span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">target_file</span><span class="p">)]</span>
+        <span class="n">img_and_target_path_list</span> <span class="o">=</span> <span class="p">[(</span><span class="n">img_file</span><span class="p">,</span> <span class="n">target_file</span><span class="p">)</span> <span class="k">for</span> <span class="n">img_file</span><span class="p">,</span> <span class="n">target_file</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">img_files</span><span class="p">,</span> <span class="n">target_fi
         <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">img_and_target_path_list</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
         <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">img_and_target_path_list</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
             <span class="k">raise</span> <span class="ne">FileNotFoundError</span><span class="p">(</span><span class="s2">&quot;No target file associated to the images was found&quot;</span><span class="p">)</span>
             <span class="k">raise</span> <span class="ne">FileNotFoundError</span><span class="p">(</span><span class="s2">&quot;No target file associated to the images was found&quot;</span><span class="p">)</span>
 
 
         <span class="n">num_missing_files</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">img_files</span><span class="p">)</span> <span class="o">-</span> <span class="nb">len</span><span class="p">(</span><span class="n">img_and_target_path_list</span><span class="p">)</span>
         <span class="n">num_missing_files</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">img_files</span><span class="p">)</span> <span class="o">-</span> <span class="nb">len</span><span class="p">(</span><span class="n">img_and_target_path_list</span><span class="p">)</span>
         <span class="k">if</span> <span class="n">num_missing_files</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
         <span class="k">if</span> <span class="n">num_missing_files</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
-            <span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">num_missing_files</span><span class="si">}</span><span class="s1"> labels files were not loaded our of </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">img_files</span><span class="p">)</span><span class="si">}</span><span class="s1"> 
+            <span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">num_missing_files</span><span class="si">}</span><span class="s2"> labels files were not loaded our of </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">img_files</span><span class="p">)</span><span class="si">}</span><span class="s2">
 
 
         <span class="bp">self</span><span class="o">.</span><span class="n">img_and_target_path_list</span> <span class="o">=</span> <span class="n">img_and_target_path_list</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">img_and_target_path_list</span> <span class="o">=</span> <span class="n">img_and_target_path_list</span>
         <span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">img_and_target_path_list</span><span class="p">)</span>
         <span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">img_and_target_path_list</span><span class="p">)</span>
@@ -156,38 +168,37 @@
 <span class="sd">                    - img_path</span>
 <span class="sd">                    - img_path</span>
 <span class="sd">        &quot;&quot;&quot;</span>
 <span class="sd">        &quot;&quot;&quot;</span>
         <span class="n">img_path</span><span class="p">,</span> <span class="n">target_path</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">img_and_target_path_list</span><span class="p">[</span><span class="n">sample_id</span><span class="p">]</span>
         <span class="n">img_path</span><span class="p">,</span> <span class="n">target_path</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">img_and_target_path_list</span><span class="p">[</span><span class="n">sample_id</span><span class="p">]</span>
-        <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">target_path</span><span class="p">,</span> <span class="s1">&#39;r&#39;</span><span class="p">)</span> <span class="k">as</span> <span class="n">targets_file</span><span class="p">:</span>
+        <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">target_path</span><span class="p">,</span> <span class="s2">&quot;r&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">targets_file</span><span class="p">:</span>
             <span class="n">target</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">x</span><span class="o">.</span><span class="n">split</span><span class="p">()</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">targets_file</span><span class="o">.</span><span class="n">read</span><span class="p">()</span><span class="o">.</span><span class="n
             <span class="n">target</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">x</span><span class="o">.</span><span class="n">split</span><span class="p">()</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">targets_file</span><span class="o">.</span><span class="n">read</span><span class="p">()</span><span class="o">.</span><span class="n
 
 
-        <span class="n">width</span><span class="p">,</span> <span class="n">height</span> <span class="o">=</span> <span class="n">get_image_size_from_path</span><span class="p">(</span><span class="n">img_path</span><span class="p">)</span>
+        <span class="n">height</span><span class="p">,</span> <span class="n">width</span> <span class="o">=</span> <span class="n">get_image_size_from_path</span><span class="p">(</span><span class="n">img_path</span><span class="p">)</span>
 
 
-        <span class="c1"># We have to rescale the targets because the images will be rescaled.</span>
+        <span class="c1"># We have to rescale the targets because the images will be resized.</span>
         <span class="n">r</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">/</span> <span class="n">height</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">0</span><span class="p">]</s
         <span class="n">r</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">/</span> <span class="n">height</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">0</span><span class="p">]</s
         <span class="n">target</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">]</span> <span class="o">*=</span> <span class="n">r</span>
         <span class="n">target</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">]</span> <span class="o">*=</span> <span class="n">r</span>
 
 
-        <span class="n">initial_img_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">width</span><span class="p">,</span> <span class="n">height</span><span class="p">)</span>
-        <span class="n">resized_img_shape</span> <span class="o">=</span> <span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">width</span> <span class="o">*</span> <span class="n">r</span><span class="p">),</span> <span class="nb">int</span><span class="p">(</span><span class="n">height</span> <span class="o">*</span> <span class="n">r</span><span class="p">))</span>
+        <span class="n">resized_img_shape</span> <span class="o">=</span> <span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">height</span> <span class="o">*</span> <span class="n">r</span><span class="p">),</span> <span class="nb">int</span><span class="p">(</span><span class="n">width</span> <span class="o">*</span> <span class="n">r</span><span class="p">))</span>
 
 
-        <span class="k">return</span> <span class="p">{</span><span class="s2">&quot;img_path&quot;</span><span class="p">:</span> <span class="n">img_path</span><span class="p">,</span> <span class="s2">&quot;target&quot;</span><span class="p">:</span> <span class="n">target</span><span class="p">,</span>
-                <span class="s2">&quot;initial_img_shape&quot;</span><span class="p">:</span> <span class="n">initial_img_shape</span><span class="p">,</span> <span class="s2">&quot;resized_img_shape&quot;</span><span class="p">:</span> <span class="n">resized_img_shape</span><span class="p">}</span>
+        <span class="k">return</span> <span class="p">{</span><span class="s2">&quot;img_path&quot;</span><span class="p">:</span> <span class="n">img_path</span><span class="p">,</span> <span class="s2">&quot;target&quot;</span><span class="p">:</span> <span class="n">target</span><span class="p">,</span> <span class="s2">&quot;resized_img_shape&quot;</span><span class="p">:</span> <span class="n">resized_img_shape</span><span class="p">}</span>
 
 
-<div class="viewcode-block" id="PascalVOCDetectionDataset.download"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.detection_datasets.html#super_gradients.training.datasets.PascalVOCDetectionDataset.download">[docs]</a>    <span class="nd">@staticmethod</span>
+<div class="viewcode-block" id="PascalVOCDetectionDataset.download"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.PascalVOCDetectionDataset.download">[docs]</a>    <span class="nd">@staticmethod</span>
     <span class="k">def</span> <span class="nf">download</span><span class="p">(</span><span class="n">data_dir</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
     <span class="k">def</span> <span class="nf">download</span><span class="p">(</span><span class="n">data_dir</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;Download Pascal dataset in XYXY_LABEL format.</span>
         <span class="sd">&quot;&quot;&quot;Download Pascal dataset in XYXY_LABEL format.</span>
 
 
 <span class="sd">        Data extracted form http://host.robots.ox.ac.uk/pascal/VOC/</span>
 <span class="sd">        Data extracted form http://host.robots.ox.ac.uk/pascal/VOC/</span>
 <span class="sd">        &quot;&quot;&quot;</span>
 <span class="sd">        &quot;&quot;&quot;</span>
+
         <span class="k">def</span> <span class="nf">_parse_and_save_labels</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="n">new_label_path</span><span class="p">,</span> <span class="n">year</span><span class="p">,</span> <span class="n">image_id</span><span class="p">):</span>
         <span class="k">def</span> <span class="nf">_parse_and_save_labels</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="n">new_label_path</span><span class="p">,</span> <span class="n">year</span><span class="p">,</span> <span class="n">image_id</span><span class="p">):</span>
             <span class="sd">&quot;&quot;&quot;Parse and save the labels of an image in XYXY_LABEL format.&quot;&quot;&quot;</span>
             <span class="sd">&quot;&quot;&quot;Parse and save the labels of an image in XYXY_LABEL format.&quot;&quot;&quot;</span>
 
 
-            <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">path</span><span class="si">}</span><span class="s1">/VOC</span><span class="si">{</span><span class="n">year</span><span class="si">}</span><span class="s1">/Annotations/</span><span class="si">{</span><span class="n">image_id</span><span class="si">}</span><span class="s1">.xml&#39;</span><span class="p">)</s
+            <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">path</span><span class="si">}</span><span class="s2">/VOC</span><span class="si">{</span><span class="n">year</span><span class="si">}</span><span class="s2">/Annotations/</span><span class="si">{</span><span class="n">image_id</span><span class="si">}</span><span class="s2">.xml&quot;</span><span class="p">)<
                 <span class="n">xml_parser</span> <span class="o">=</span> <span class="n">ElementTree</span><span class="o">.</span><span class="n">parse</span><span class="p">(</span><span class="n">f</span><span class="p">)</span><span class="o">.</span><span class="n">getroot</span><span class="p">()</span>
                 <span class="n">xml_parser</span> <span class="o">=</span> <span class="n">ElementTree</span><span class="o">.</span><span class="n">parse</span><span class="p">(</span><span class="n">f</span><span class="p">)</span><span class="o">.</span><span class="n">getroot</span><span class="p">()</span>
 
 
             <span class="n">labels</span> <span class="o">=</span> <span class="p">[]</span>
             <span class="n">labels</span> <span class="o">=</span> <span class="p">[]</span>
-            <span class="k">for</span> <span class="n">obj</span> <span class="ow">in</span> <span class="n">xml_parser</span><span class="o">.</span><span class="n">iter</span><span class="p">(</span><span class="s1">&#39;object&#39;</span><span class="p">):</span>
-                <span class="bp">cls</span> <span class="o">=</span> <span class="n">obj</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="s1">&#39;name&#39;</span><span class="p">)</span><span class="o">.</span><span class="n">text</span>
-                <span class="k">if</span> <span class="bp">cls</span> <span class="ow">in</span> <span class="n">PASCAL_VOC_2012_CLASSES_LIST</span> <span class="ow">and</span> <span class="ow">not</span> <span class="nb">int</span><span class="p">(</span><span class="n">obj</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="s1">&#39;difficult&#39;</span><span class="p">)</span><span class="o">.</span><span class="n">text</span><span class="p">)</span>
-                    <span class="n">xml_box</span> <span class="o">=</span> <span class="n">obj</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="s1">&#39;bndbox&#39;</span><span class="p">)</span>
+            <span class="k">for</span> <span class="n">obj</span> <span class="ow">in</span> <span class="n">xml_parser</span><span class="o">.</span><span class="n">iter</span><span class="p">(</span><span class="s2">&quot;object&quot;</span><span class="p">):</span>
+                <span class="bp">cls</span> <span class="o">=</span> <span class="n">obj</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="s2">&quot;name&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">text</span>
+                <span class="k">if</span> <span class="bp">cls</span> <span class="ow">in</span> <span class="n">PASCAL_VOC_2012_CLASSES_LIST</span> <span class="ow">and</span> <span class="ow">not</span> <span class="nb">int</span><span class="p">(</span><span class="n">obj</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="s2">&quot;difficult&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">text</span><span class="p">)</spa
+                    <span class="n">xml_box</span> <span class="o">=</span> <span class="n">obj</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="s2">&quot;bndbox&quot;</span><span class="p">)</span>
 
 
                     <span class="k">def</span> <span class="nf">get_coord</span><span class="p">(</span><span class="n">box_coord</span><span class="p">):</span>
                     <span class="k">def</span> <span class="nf">get_coord</span><span class="p">(</span><span class="n">box_coord</span><span class="p">):</span>
                         <span class="k">return</span> <span class="n">xml_box</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="n">box_coord</span><span class="p">)</span><span class="o">.</span><span class="n">text</span>
                         <span class="k">return</span> <span class="n">xml_box</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="n">box_coord</span><span class="p">)</span><span class="o">.</span><span class="n">text</span>
@@ -195,33 +206,72 @@
                     <span class="n">xmin</span><span class="p">,</span> <span class="n">ymin</span><span class="p">,</span> <span class="n">xmax</span><span class="p">,</span> <span class="n">ymax</span> <span class="o">=</span> <span class="n">get_coord</span><span class="p">(</span><span class="s2">&quot;xmin&quot;</span><span class="p">),</span> <span class="n">get_coord</span><span class="p">(</span><span class="s2">&quot;ymin&quot;</span><span class="p">),</span> <span class="n">get_coord<
                     <span class="n">xmin</span><span class="p">,</span> <span class="n">ymin</span><span class="p">,</span> <span class="n">xmax</span><span class="p">,</span> <span class="n">ymax</span> <span class="o">=</span> <span class="n">get_coord</span><span class="p">(</span><span class="s2">&quot;xmin&quot;</span><span class="p">),</span> <span class="n">get_coord</span><span class="p">(</span><span class="s2">&quot;ymin&quot;</span><span class="p">),</span> <span class="n">get_coord<
                     <span class="n">labels</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="s2">&quot; &quot;</span><span class="o">.</span><span class="n">join</span><span class="p">([</span><span class="n">xmin</span><span class="p">,</span> <span class="n">ymin</span><span class="p">,</span> <span class="n">xmax</span><span class="p">,</span> <span class="n">ymax</span><span class="p">,</span> <span class="nb">str</span><span class="p">(</span><
                     <span class="n">labels</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="s2">&quot; &quot;</span><span class="o">.</span><span class="n">join</span><span class="p">([</span><span class="n">xmin</span><span class="p">,</span> <span class="n">ymin</span><span class="p">,</span> <span class="n">xmax</span><span class="p">,</span> <span class="n">ymax</span><span class="p">,</span> <span class="nb">str</span><span class="p">(</span><
 
 
-            <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">new_label_path</span><span class="p">,</span> <span class="s1">&#39;w&#39;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
+            <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">new_label_path</span><span class="p">,</span> <span class="s2">&quot;w&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
                 <span class="n">f</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">labels</span><span class="p">))</span>
                 <span class="n">f</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">labels</span><span class="p">))</span>
 
 
-        <span class="n">urls</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar&quot;</span><span class="p">,</span>  <span class="c1"># 439M 5011 images</span>
-                <span class="s2">&quot;http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar&quot;</span><span class="p">,</span>  <span class="c1"># 430M, 4952 images</span>
-                <span class="s2">&quot;http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar&quot;</span><span class="p">]</span>  <span class="c1"># 1.86G, 17125 images</span>
+        <span class="n">urls</span> <span class="o">=</span> <span class="p">[</span>
+            <span class="s2">&quot;http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar&quot;</span><span class="p">,</span>  <span class="c1"># 439M 5011 images</span>
+            <span class="s2">&quot;http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar&quot;</span><span class="p">,</span>  <span class="c1"># 430M, 4952 images</span>
+            <span class="s2">&quot;http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar&quot;</span><span class="p">,</span>
+        <span class="p">]</span>  <span class="c1"># 1.86G, 17125 images</span>
         <span class="n">data_dir</span> <span class="o">=</span> <span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span>
         <span class="n">data_dir</span> <span class="o">=</span> <span class="n">Path</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span>
-        <span class="n">download_and_untar_from_url</span><span class="p">(</span><span class="n">urls</span><span class="p">,</span> <span class="nb">dir</span><span class="o">=</span><span class="n">data_dir</span> <span class="o">/</span> <span class="s1">&#39;images&#39;</span><span class="p">)</span>
+        <span class="n">download_and_untar_from_url</span><span class="p">(</span><span class="n">urls</span><span class="p">,</span> <span class="nb">dir</span><span class="o">=</span><span class="n">data_dir</span> <span class="o">/</span> <span class="s2">&quot;images&quot;</span><span class="p">)</span>
 
 
         <span class="c1"># Convert</span>
         <span class="c1"># Convert</span>
-        <span class="n">data_path</span> <span class="o">=</span> <span class="n">data_dir</span> <span class="o">/</span> <span class="s1">&#39;images&#39;</span> <span class="o">/</span> <span class="s1">&#39;VOCdevkit&#39;</span>
-        <span class="k">for</span> <span class="n">year</span><span class="p">,</span> <span class="n">image_set</span> <span class="ow">in</span> <span class="p">(</span><span class="s1">&#39;2012&#39;</span><span class="p">,</span> <span class="s1">&#39;train&#39;</span><span class="p">),</span> <span class="p">(</span><span class="s1">&#39;2012&#39;</span><span class="p">,</span> <span class="s1">&#39;val&#39;</span><span class="p">),</span> <span class="p">(</span><span class="s1">&#39;2007
-            <span class="n">dest_imgs_path</span> <span class="o">=</span> <span class="n">data_dir</span> <span class="o">/</span> <span class="s1">&#39;images&#39;</span> <span class="o">/</span> <span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">image_set</span><span class="si">}{</span><span class="n">year</span><span class="si">}</span><span class="s1">&#39;</span>
+        <span class="n">data_path</span> <span class="o">=</span> <span class="n">data_dir</span> <span class="o">/</span> <span class="s2">&quot;images&quot;</span> <span class="o">/</span> <span class="s2">&quot;VOCdevkit&quot;</span>
+        <span class="k">for</span> <span class="n">year</span><span class="p">,</span> <span class="n">image_set</span> <span class="ow">in</span> <span class="p">(</span><span class="s2">&quot;2012&quot;</span><span class="p">,</span> <span class="s2">&quot;train&quot;</span><span class="p">),</span> <span class="p">(</span><span class="s2">&quot;2012&quot;</span><span class="p">,</span> <span class="s2">&quot;val&quot;</span><span class="p">),</span> <span class="p">(</span><span class="s2">&
+            <span class="n">dest_imgs_path</span> <span class="o">=</span> <span class="n">data_dir</span> <span class="o">/</span> <span class="s2">&quot;images&quot;</span> <span class="o">/</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">image_set</span><span class="si">}{</span><span class="n">year</span><span class="si">}</span><span class="s2">&quot;</span>
             <span class="n">dest_imgs_path</span><span class="o">.</span><span class="n">mkdir</span><span class="p">(</span><span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">parents</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
             <span class="n">dest_imgs_path</span><span class="o">.</span><span class="n">mkdir</span><span class="p">(</span><span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">parents</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
 
 
-            <span class="n">dest_labels_path</span> <span class="o">=</span> <span class="n">data_dir</span> <span class="o">/</span> <span class="s1">&#39;labels&#39;</span> <span class="o">/</span> <span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">image_set</span><span class="si">}{</span><span class="n">year</span><span class="si">}</span><span class="s1">&#39;</span>
+            <span class="n">dest_labels_path</span> <span class="o">=</span> <span class="n">data_dir</span> <span class="o">/</span> <span class="s2">&quot;labels&quot;</span> <span class="o">/</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">image_set</span><span class="si">}{</span><span class="n">year</span><span class="si">}</span><span class="s2">&quot;</span>
             <span class="n">dest_labels_path</span><span class="o">.</span><span class="n">mkdir</span><span class="p">(</span><span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">parents</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
             <span class="n">dest_labels_path</span><span class="o">.</span><span class="n">mkdir</span><span class="p">(</span><span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">parents</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
 
 
-            <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">data_path</span> <span class="o">/</span> <span class="sa">f</span><span class="s1">&#39;VOC</span><span class="si">{</span><span class="n">year</span><span class="si">}</span><span class="s1">/ImageSets/Main/</span><span class="si">{</span><span class="n">image_set</span><span class="si">}</span><span class="s1">.txt&#39;</span><span class="p">)</span> <span class="k">as</span> <span cla
+            <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">data_path</span> <span class="o">/</span> <span class="sa">f</span><span class="s2">&quot;VOC</span><span class="si">{</span><span class="n">year</span><span class="si">}</span><span class="s2">/ImageSets/Main/</span><span class="si">{</span><span class="n">image_set</span><span class="si">}</span><span class="s2">.txt&quot;</span><span class="p">)</span> <span class="k">as</span> <span c
                 <span class="n">image_ids</span> <span class="o">=</span> <span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">()</span><span class="o">.</span><span class="n">strip</span><span class="p">()</span><span class="o">.</span><span class="n">split</span><span class="p">()</span>
                 <span class="n">image_ids</span> <span class="o">=</span> <span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">()</span><span class="o">.</span><span class="n">strip</span><span class="p">()</span><span class="o">.</span><span class="n">split</span><span class="p">()</span>
 
 
-            <span class="k">for</span> <span class="nb">id</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">(</span><span class="n">image_ids</span><span class="p">,</span> <span class="n">desc</span><span class="o">=</span><span class="sa">f</span><span class="s1">&#39;</span><span class="si">{</span><span class="n">image_set</span><span class="si">}{</span><span class="n">year</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">):</span>
-                <span class="n">img_path</span> <span class="o">=</span> <span class="n">data_path</span> <span class="o">/</span> <span class="sa">f</span><span class="s1">&#39;VOC</span><span class="si">{</span><span class="n">year</span><span class="si">}</span><span class="s1">/JPEGImages/</span><span class="si">{</span><span class="nb">id</span><span class="si">}</span><span class="s1">.jpg&#39;</span>
+            <span class="k">for</span> <span class="nb">id</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">(</span><span class="n">image_ids</span><span class="p">,</span> <span class="n">desc</span><span class="o">=</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">image_set</span><span class="si">}{</span><span class="n">year</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">):</span>
+                <span class="n">img_path</span> <span class="o">=</span> <span class="n">data_path</span> <span class="o">/</span> <span class="sa">f</span><span class="s2">&quot;VOC</span><span class="si">{</span><span class="n">year</span><span class="si">}</span><span class="s2">/JPEGImages/</span><span class="si">{</span><span class="nb">id</span><span class="si">}</span><span class="s2">.jpg&quot;</span>
                 <span class="n">new_img_path</span> <span class="o">=</span> <span class="n">dest_imgs_path</span> <span class="o">/</span> <span class="n">img_path</span><span class="o">.</span><span class="n">name</span>
                 <span class="n">new_img_path</span> <span class="o">=</span> <span class="n">dest_imgs_path</span> <span class="o">/</span> <span class="n">img_path</span><span class="o">.</span><span class="n">name</span>
-                <span class="n">new_label_path</span> <span class="o">=</span> <span class="p">(</span><span class="n">dest_labels_path</span> <span class="o">/</span> <span class="n">img_path</span><span class="o">.</span><span class="n">name</span><span class="p">)</span><span class="o">.</span><span class="n">with_suffix</span><span class="p">(</span><span class="s1">&#39;.txt&#39;</span><span class="p">)</span>
+                <span class="n">new_label_path</span> <span class="o">=</span> <span class="p">(</span><span class="n">dest_labels_path</span> <span class="o">/</span> <span class="n">img_path</span><span class="o">.</span><span class="n">name</span><span class="p">)</span><span class="o">.</span><span class="n">with_suffix</span><span class="p">(</span><span class="s2">&quot;.txt&quot;</span><span class="p">)</span>
                 <span class="n">img_path</span><span class="o">.</span><span class="n">rename</span><span class="p">(</span><span class="n">new_img_path</span><span class="p">)</span>  <span class="c1"># Move image to dest folder</span>
                 <span class="n">img_path</span><span class="o">.</span><span class="n">rename</span><span class="p">(</span><span class="n">new_img_path</span><span class="p">)</span>  <span class="c1"># Move image to dest folder</span>
                 <span class="n">_parse_and_save_labels</span><span class="p">(</span><span class="n">data_path</span><span class="p">,</span> <span class="n">new_label_path</span><span class="p">,</span> <span class="n">year</span><span class="p">,</span> <span class="nb">id</span><span class="p">)</span></div></div>
                 <span class="n">_parse_and_save_labels</span><span class="p">(</span><span class="n">data_path</span><span class="p">,</span> <span class="n">new_label_path</span><span class="p">,</span> <span class="n">year</span><span class="p">,</span> <span class="nb">id</span><span class="p">)</span></div></div>
+
+
+<span class="k">class</span> <span class="nc">PascalVOCUnifiedDetectionTrainDataset</span><span class="p">(</span><span class="n">ConcatDataset</span><span class="p">):</span>
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">data_dir</span><span class="p">,</span>
+        <span class="n">input_dim</span><span class="p">:</span> <span class="nb">tuple</span><span class="p">,</span>
+        <span class="n">cache</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
+        <span class="n">cache_dir</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">transforms</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">DetectionTransform</span><span class="p">]</span> <span class="o">=</span> <span class="p">[],</span>
+        <span class="n">class_inclusion_list</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">max_num_samples</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">download</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
+    <span class="p">):</span>
+        <span class="k">if</span> <span class="n">download</span><span class="p">:</span>
+            <span class="n">PascalVOCDetectionDataset</span><span class="o">.</span><span class="n">download</span><span class="p">(</span><span class="n">data_dir</span><span class="o">=</span><span class="n">data_dir</span><span class="p">)</span>
+
+        <span class="n">train_dataset_names</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;train2007&quot;</span><span class="p">,</span> <span class="s2">&quot;val2007&quot;</span><span class="p">,</span> <span class="s2">&quot;train2012&quot;</span><span class="p">,</span> <span class="s2">&quot;val2012&quot;</span><span class="p">]</span>
+        <span class="c1"># We divide train_max_num_samples between the datasets</span>
+        <span class="k">if</span> <span class="n">max_num_samples</span><span class="p">:</span>
+            <span class="n">max_num_samples_per_train_dataset</span> <span class="o">=</span> <span class="p">[</span><span class="nb">len</span><span class="p">(</span><span class="n">segment</span><span class="p">)</span> <span class="k">for</span> <span class="n">segment</span> <span class="ow">in</span> <span class="n">np</span><span class="o">.</span><span class="n">array_split</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">max_num_sample
+        <span class="k">else</span><span class="p">:</span>
+            <span class="n">max_num_samples_per_train_dataset</span> <span class="o">=</span> <span class="p">[</span><span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">train_dataset_names</span><span class="p">)</span>
+        <span class="n">train_sets</span> <span class="o">=</span> <span class="p">[</span>
+            <span class="n">PascalVOCDetectionDataset</span><span class="p">(</span>
+                <span class="n">data_dir</span><span class="o">=</span><span class="n">data_dir</span><span class="p">,</span>
+                <span class="n">input_dim</span><span class="o">=</span><span class="n">input_dim</span><span class="p">,</span>
+                <span class="n">cache</span><span class="o">=</span><span class="n">cache</span><span class="p">,</span>
+                <span class="n">cache_dir</span><span class="o">=</span><span class="n">cache_dir</span><span class="p">,</span>
+                <span class="n">transforms</span><span class="o">=</span><span class="n">transforms</span><span class="p">,</span>
+                <span class="n">images_sub_directory</span><span class="o">=</span><span class="s2">&quot;images/&quot;</span> <span class="o">+</span> <span class="n">trainset_name</span> <span class="o">+</span> <span class="s2">&quot;/&quot;</span><span class="p">,</span>
+                <span class="n">class_inclusion_list</span><span class="o">=</span><span class="n">class_inclusion_list</span><span class="p">,</span>
+                <span class="n">max_num_samples</span><span class="o">=</span><span class="n">max_num_samples_per_train_dataset</span><span class="p">[</span><span class="n">i</span><span class="p">],</span>
+            <span class="p">)</span>
+            <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">trainset_name</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">train_dataset_names</span><span class="p">)</span>
+        <span class="p">]</span>
+        <span class="nb">super</span><span class="p">(</span><span class="n">PascalVOCUnifiedDetectionTrainDataset</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">train_sets</span><span class="p">)</span>
 </pre></div>
 </pre></div>
 
 
            </div>
            </div>
@@ -251,4 +301,4 @@
   </script> 
   </script> 
 
 
 </body>
 </body>
-</html>
+</html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.training.datasets.mixup &mdash; SuperGradients 1.0 documentation</title>
  7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
  10. <!--[if lt IE 9]>
  11. <script src="../../../../_static/js/html5shiv.min.js"></script>
  12. <![endif]-->
  13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
  14. <script src="../../../../_static/jquery.js"></script>
  15. <script src="../../../../_static/underscore.js"></script>
  16. <script src="../../../../_static/doctools.js"></script>
  17. <script src="../../../../_static/js/theme.js"></script>
  18. <link rel="index" title="Index" href="../../../../genindex.html" />
  19. <link rel="search" title="Search" href="../../../../search.html" />
  20. </head>
  21. <body class="wy-body-for-nav">
  22. <div class="wy-grid-for-nav">
  23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  24. <div class="wy-side-scroll">
  25. <div class="wy-side-nav-search" >
  26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
  27. </a>
  28. <div role="search">
  29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
  30. <input type="text" name="q" placeholder="Search docs" />
  31. <input type="hidden" name="check_keywords" value="yes" />
  32. <input type="hidden" name="area" value="default" />
  33. </form>
  34. </div>
  35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
  37. <ul>
  38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
  39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
  40. </ul>
  41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
  42. <ul>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
  45. </ul>
  46. <p class="caption"><span class="caption-text">User Guide</span></p>
  47. <ul>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
  57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
  58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
  59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
  60. </ul>
  61. </div>
  62. </div>
  63. </nav>
  64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  66. <a href="../../../../index.html">SuperGradients</a>
  67. </nav>
  68. <div class="wy-nav-content">
  69. <div class="rst-content">
  70. <div role="navigation" aria-label="Page navigation">
  71. <ul class="wy-breadcrumbs">
  72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
  74. <li>super_gradients.training.datasets.mixup</li>
  75. <li class="wy-breadcrumbs-aside">
  76. </li>
  77. </ul>
  78. <hr/>
  79. </div>
  80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  81. <div itemprop="articleBody">
  82. <h1>Source code for super_gradients.training.datasets.mixup</h1><div class="highlight"><pre>
  83. <span></span><span class="sd">&quot;&quot;&quot; Mixup and Cutmix</span>
  84. <span class="sd">Papers:</span>
  85. <span class="sd">mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)</span>
  86. <span class="sd">CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899)</span>
  87. <span class="sd">Code Reference:</span>
  88. <span class="sd">CutMix: https://github.com/clovaai/CutMix-PyTorch</span>
  89. <span class="sd">CutMix by timm: https://github.com/rwightman/pytorch-image-models/timm</span>
  90. <span class="sd">&quot;&quot;&quot;</span>
  91. <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Union</span>
  92. <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
  93. <span class="kn">import</span> <span class="nn">torch</span>
  94. <span class="kn">from</span> <span class="nn">super_gradients.training.exceptions.dataset_exceptions</span> <span class="kn">import</span> <span class="n">IllegalDatasetParameterException</span>
  95. <div class="viewcode-block" id="one_hot"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.mixup.one_hot">[docs]</a><span class="k">def</span> <span class="nf">one_hot</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">,</span> <span class="n">on_value</span><span class="o">=</span><span class="mf">1.</span><span class="p">,</span> <span class="n">off_value</span><span class="o">=</span><span class="mf">0.</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cuda&#39;</span><span class="p">):</span>
  96. <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">long</span><span class="p">()</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
  97. <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">x</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="mi">0</span><span class="p">],</span> <span class="n">num_classes</span><span class="p">),</span> <span class="n">off_value</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span><span class="o">.</span><span class="n">scatter_</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">on_value</span><span class="p">)</span></div>
  98. <div class="viewcode-block" id="mixup_target"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.mixup.mixup_target">[docs]</a><span class="k">def</span> <span class="nf">mixup_target</span><span class="p">(</span><span class="n">target</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">lam</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.</span><span class="p">,</span> <span class="n">smoothing</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span> <span class="n">device</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39;cuda&#39;</span><span class="p">):</span>
  99. <span class="sd">&quot;&quot;&quot;</span>
  100. <span class="sd"> generate a smooth target (label) two-hot tensor to support the mixed images with different labels</span>
  101. <span class="sd"> :param target: the targets tensor</span>
  102. <span class="sd"> :param num_classes: number of classes (to set the final tensor size)</span>
  103. <span class="sd"> :param lam: percentage of label a range [0, 1] in the mixing</span>
  104. <span class="sd"> :param smoothing: the smoothing multiplier</span>
  105. <span class="sd"> :param device: usable device [&#39;cuda&#39;, &#39;cpu&#39;]</span>
  106. <span class="sd"> :return:</span>
  107. <span class="sd"> &quot;&quot;&quot;</span>
  108. <span class="n">off_value</span> <span class="o">=</span> <span class="n">smoothing</span> <span class="o">/</span> <span class="n">num_classes</span>
  109. <span class="n">on_value</span> <span class="o">=</span> <span class="mf">1.</span> <span class="o">-</span> <span class="n">smoothing</span> <span class="o">+</span> <span class="n">off_value</span>
  110. <span class="n">y1</span> <span class="o">=</span> <span class="n">one_hot</span><span class="p">(</span><span class="n">target</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">,</span> <span class="n">on_value</span><span class="o">=</span><span class="n">on_value</span><span class="p">,</span> <span class="n">off_value</span><span class="o">=</span><span class="n">off_value</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
  111. <span class="n">y2</span> <span class="o">=</span> <span class="n">one_hot</span><span class="p">(</span><span class="n">target</span><span class="o">.</span><span class="n">flip</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">num_classes</span><span class="p">,</span> <span class="n">on_value</span><span class="o">=</span><span class="n">on_value</span><span class="p">,</span> <span class="n">off_value</span><span class="o">=</span><span class="n">off_value</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
  112. <span class="k">return</span> <span class="n">y1</span> <span class="o">*</span> <span class="n">lam</span> <span class="o">+</span> <span class="n">y2</span> <span class="o">*</span> <span class="p">(</span><span class="mf">1.</span> <span class="o">-</span> <span class="n">lam</span><span class="p">)</span></div>
  113. <div class="viewcode-block" id="rand_bbox"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.mixup.rand_bbox">[docs]</a><span class="k">def</span> <span class="nf">rand_bbox</span><span class="p">(</span><span class="n">img_shape</span><span class="p">:</span> <span class="nb">tuple</span><span class="p">,</span> <span class="n">lam</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">margin</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.</span><span class="p">,</span> <span class="n">count</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  114. <span class="sd">&quot;&quot;&quot; Standard CutMix bounding-box</span>
  115. <span class="sd"> Generates a random square bbox based on lambda value. This impl includes</span>
  116. <span class="sd"> support for enforcing a border margin as percent of bbox dimensions.</span>
  117. <span class="sd"> :param img_shape: Image shape as tuple</span>
  118. <span class="sd"> :param lam: Cutmix lambda value</span>
  119. <span class="sd"> :param margin: Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)</span>
  120. <span class="sd"> :param count: Number of bbox to generate</span>
  121. <span class="sd"> &quot;&quot;&quot;</span>
  122. <span class="n">ratio</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">lam</span><span class="p">)</span>
  123. <span class="n">img_h</span><span class="p">,</span> <span class="n">img_w</span> <span class="o">=</span> <span class="n">img_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">2</span><span class="p">:]</span>
  124. <span class="n">cut_h</span><span class="p">,</span> <span class="n">cut_w</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">img_h</span> <span class="o">*</span> <span class="n">ratio</span><span class="p">),</span> <span class="nb">int</span><span class="p">(</span><span class="n">img_w</span> <span class="o">*</span> <span class="n">ratio</span><span class="p">)</span>
  125. <span class="n">margin_y</span><span class="p">,</span> <span class="n">margin_x</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">margin</span> <span class="o">*</span> <span class="n">cut_h</span><span class="p">),</span> <span class="nb">int</span><span class="p">(</span><span class="n">margin</span> <span class="o">*</span> <span class="n">cut_w</span><span class="p">)</span>
  126. <span class="n">cy</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span> <span class="o">+</span> <span class="n">margin_y</span><span class="p">,</span> <span class="n">img_h</span> <span class="o">-</span> <span class="n">margin_y</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="n">count</span><span class="p">)</span>
  127. <span class="n">cx</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span> <span class="o">+</span> <span class="n">margin_x</span><span class="p">,</span> <span class="n">img_w</span> <span class="o">-</span> <span class="n">margin_x</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="n">count</span><span class="p">)</span>
  128. <span class="n">yl</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">cy</span> <span class="o">-</span> <span class="n">cut_h</span> <span class="o">//</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">img_h</span><span class="p">)</span>
  129. <span class="n">yh</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">cy</span> <span class="o">+</span> <span class="n">cut_h</span> <span class="o">//</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">img_h</span><span class="p">)</span>
  130. <span class="n">xl</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">cx</span> <span class="o">-</span> <span class="n">cut_w</span> <span class="o">//</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">img_w</span><span class="p">)</span>
  131. <span class="n">xh</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">cx</span> <span class="o">+</span> <span class="n">cut_w</span> <span class="o">//</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">img_w</span><span class="p">)</span>
  132. <span class="k">return</span> <span class="n">yl</span><span class="p">,</span> <span class="n">yh</span><span class="p">,</span> <span class="n">xl</span><span class="p">,</span> <span class="n">xh</span></div>
  133. <div class="viewcode-block" id="rand_bbox_minmax"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.mixup.rand_bbox_minmax">[docs]</a><span class="k">def</span> <span class="nf">rand_bbox_minmax</span><span class="p">(</span><span class="n">img_shape</span><span class="p">:</span> <span class="nb">tuple</span><span class="p">,</span> <span class="n">minmax</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">tuple</span><span class="p">,</span> <span class="nb">list</span><span class="p">],</span> <span class="n">count</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  134. <span class="sd">&quot;&quot;&quot; Min-Max CutMix bounding-box</span>
  135. <span class="sd"> Inspired by Darknet cutmix impl, generates a random rectangular bbox</span>
  136. <span class="sd"> based on min/max percent values applied to each dimension of the input image.</span>
  137. <span class="sd"> Typical defaults for minmax are usually in the .2-.3 for min and .8-.9 range for max.</span>
  138. <span class="sd"> :param img_shape: Image shape as tuple</span>
  139. <span class="sd"> :param minmax: Min and max bbox ratios (as percent of image size)</span>
  140. <span class="sd"> :param count: Number of bbox to generate</span>
  141. <span class="sd"> &quot;&quot;&quot;</span>
  142. <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">minmax</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span>
  143. <span class="n">img_h</span><span class="p">,</span> <span class="n">img_w</span> <span class="o">=</span> <span class="n">img_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">2</span><span class="p">:]</span>
  144. <span class="n">cut_h</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">img_h</span> <span class="o">*</span> <span class="n">minmax</span><span class="p">[</span><span class="mi">0</span><span class="p">]),</span> <span class="nb">int</span><span class="p">(</span><span class="n">img_h</span> <span class="o">*</span> <span class="n">minmax</span><span class="p">[</span><span class="mi">1</span><span class="p">]),</span> <span class="n">size</span><span class="o">=</span><span class="n">count</span><span class="p">)</span>
  145. <span class="n">cut_w</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">img_w</span> <span class="o">*</span> <span class="n">minmax</span><span class="p">[</span><span class="mi">0</span><span class="p">]),</span> <span class="nb">int</span><span class="p">(</span><span class="n">img_w</span> <span class="o">*</span> <span class="n">minmax</span><span class="p">[</span><span class="mi">1</span><span class="p">]),</span> <span class="n">size</span><span class="o">=</span><span class="n">count</span><span class="p">)</span>
  146. <span class="n">yl</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">img_h</span> <span class="o">-</span> <span class="n">cut_h</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="n">count</span><span class="p">)</span>
  147. <span class="n">xl</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">img_w</span> <span class="o">-</span> <span class="n">cut_w</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="n">count</span><span class="p">)</span>
  148. <span class="n">yu</span> <span class="o">=</span> <span class="n">yl</span> <span class="o">+</span> <span class="n">cut_h</span>
  149. <span class="n">xu</span> <span class="o">=</span> <span class="n">xl</span> <span class="o">+</span> <span class="n">cut_w</span>
  150. <span class="k">return</span> <span class="n">yl</span><span class="p">,</span> <span class="n">yu</span><span class="p">,</span> <span class="n">xl</span><span class="p">,</span> <span class="n">xu</span></div>
  151. <div class="viewcode-block" id="cutmix_bbox_and_lam"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.mixup.cutmix_bbox_and_lam">[docs]</a><span class="k">def</span> <span class="nf">cutmix_bbox_and_lam</span><span class="p">(</span><span class="n">img_shape</span><span class="p">:</span> <span class="nb">tuple</span><span class="p">,</span> <span class="n">lam</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">ratio_minmax</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">tuple</span><span class="p">,</span> <span class="nb">list</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">correct_lam</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
  152. <span class="n">count</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  153. <span class="sd">&quot;&quot;&quot;</span>
  154. <span class="sd"> Generate bbox and apply lambda correction.</span>
  155. <span class="sd"> &quot;&quot;&quot;</span>
  156. <span class="k">if</span> <span class="n">ratio_minmax</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  157. <span class="n">yl</span><span class="p">,</span> <span class="n">yu</span><span class="p">,</span> <span class="n">xl</span><span class="p">,</span> <span class="n">xu</span> <span class="o">=</span> <span class="n">rand_bbox_minmax</span><span class="p">(</span><span class="n">img_shape</span><span class="p">,</span> <span class="n">ratio_minmax</span><span class="p">,</span> <span class="n">count</span><span class="o">=</span><span class="n">count</span><span class="p">)</span>
  158. <span class="k">else</span><span class="p">:</span>
  159. <span class="n">yl</span><span class="p">,</span> <span class="n">yu</span><span class="p">,</span> <span class="n">xl</span><span class="p">,</span> <span class="n">xu</span> <span class="o">=</span> <span class="n">rand_bbox</span><span class="p">(</span><span class="n">img_shape</span><span class="p">,</span> <span class="n">lam</span><span class="p">,</span> <span class="n">count</span><span class="o">=</span><span class="n">count</span><span class="p">)</span>
  160. <span class="k">if</span> <span class="n">correct_lam</span> <span class="ow">or</span> <span class="n">ratio_minmax</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  161. <span class="n">bbox_area</span> <span class="o">=</span> <span class="p">(</span><span class="n">yu</span> <span class="o">-</span> <span class="n">yl</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">xu</span> <span class="o">-</span> <span class="n">xl</span><span class="p">)</span>
  162. <span class="n">lam</span> <span class="o">=</span> <span class="mf">1.</span> <span class="o">-</span> <span class="n">bbox_area</span> <span class="o">/</span> <span class="nb">float</span><span class="p">(</span><span class="n">img_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">2</span><span class="p">]</span> <span class="o">*</span> <span class="n">img_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
  163. <span class="k">return</span> <span class="p">(</span><span class="n">yl</span><span class="p">,</span> <span class="n">yu</span><span class="p">,</span> <span class="n">xl</span><span class="p">,</span> <span class="n">xu</span><span class="p">),</span> <span class="n">lam</span></div>
  164. <div class="viewcode-block" id="CollateMixup"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.mixup.CollateMixup">[docs]</a><span class="k">class</span> <span class="nc">CollateMixup</span><span class="p">:</span>
  165. <span class="sd">&quot;&quot;&quot;</span>
  166. <span class="sd"> Collate with Mixup/Cutmix that applies different params to each element or whole batch</span>
  167. <span class="sd"> A Mixup impl that&#39;s performed while collating the batches.</span>
  168. <span class="sd"> &quot;&quot;&quot;</span>
  169. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mixup_alpha</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.</span><span class="p">,</span> <span class="n">cutmix_alpha</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.</span><span class="p">,</span> <span class="n">cutmix_minmax</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
  170. <span class="n">prob</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span> <span class="n">switch_prob</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span>
  171. <span class="n">mode</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39;batch&#39;</span><span class="p">,</span> <span class="n">correct_lam</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> <span class="n">label_smoothing</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1000</span><span class="p">):</span>
  172. <span class="sd">&quot;&quot;&quot;</span>
  173. <span class="sd"> Mixup/Cutmix that applies different params to each element or whole batch</span>
  174. <span class="sd"> :param mixup_alpha: mixup alpha value, mixup is active if &gt; 0.</span>
  175. <span class="sd"> :param cutmix_alpha: cutmix alpha value, cutmix is active if &gt; 0.</span>
  176. <span class="sd"> :param cutmix_minmax: cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.</span>
  177. <span class="sd"> :param prob: probability of applying mixup or cutmix per batch or element</span>
  178. <span class="sd"> :param switch_prob: probability of switching to cutmix instead of mixup when both are active</span>
  179. <span class="sd"> :param mode: how to apply mixup/cutmix params (per &#39;batch&#39;, &#39;pair&#39; (pair of elements), &#39;elem&#39; (element)</span>
  180. <span class="sd"> :param correct_lam: apply lambda correction when cutmix bbox clipped by image borders</span>
  181. <span class="sd"> :param label_smoothing: apply label smoothing to the mixed target tensor</span>
  182. <span class="sd"> :param num_classes: number of classes for target</span>
  183. <span class="sd"> &quot;&quot;&quot;</span>
  184. <span class="bp">self</span><span class="o">.</span><span class="n">mixup_alpha</span> <span class="o">=</span> <span class="n">mixup_alpha</span>
  185. <span class="bp">self</span><span class="o">.</span><span class="n">cutmix_alpha</span> <span class="o">=</span> <span class="n">cutmix_alpha</span>
  186. <span class="bp">self</span><span class="o">.</span><span class="n">cutmix_minmax</span> <span class="o">=</span> <span class="n">cutmix_minmax</span>
  187. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cutmix_minmax</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  188. <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cutmix_minmax</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span>
  189. <span class="c1"># force cutmix alpha == 1.0 when minmax active to keep logic simple &amp; safe</span>
  190. <span class="bp">self</span><span class="o">.</span><span class="n">cutmix_alpha</span> <span class="o">=</span> <span class="mf">1.0</span>
  191. <span class="bp">self</span><span class="o">.</span><span class="n">mix_prob</span> <span class="o">=</span> <span class="n">prob</span>
  192. <span class="bp">self</span><span class="o">.</span><span class="n">switch_prob</span> <span class="o">=</span> <span class="n">switch_prob</span>
  193. <span class="bp">self</span><span class="o">.</span><span class="n">label_smoothing</span> <span class="o">=</span> <span class="n">label_smoothing</span>
  194. <span class="bp">self</span><span class="o">.</span><span class="n">num_classes</span> <span class="o">=</span> <span class="n">num_classes</span>
  195. <span class="bp">self</span><span class="o">.</span><span class="n">mode</span> <span class="o">=</span> <span class="n">mode</span>
  196. <span class="bp">self</span><span class="o">.</span><span class="n">correct_lam</span> <span class="o">=</span> <span class="n">correct_lam</span> <span class="c1"># correct lambda based on clipped area for cutmix</span>
  197. <span class="bp">self</span><span class="o">.</span><span class="n">mixup_enabled</span> <span class="o">=</span> <span class="kc">True</span> <span class="c1"># set to false to disable mixing (intended tp be set by train loop)</span>
  198. <span class="k">def</span> <span class="nf">_params_per_elem</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">):</span>
  199. <span class="sd">&quot;&quot;&quot;</span>
  200. <span class="sd"> generate two random masks to define which elements of the batch will be mixed and how (depending on the</span>
  201. <span class="sd"> self.mixup_enabled, self.mixup_alpha, self.cutmix_alpha parameters</span>
  202. <span class="sd"> :param batch_size:</span>
  203. <span class="sd"> :return: two tensors with shape=batch_size - the first contains the lambda value per batch element</span>
  204. <span class="sd"> and the second is a binary flag indicating use of cutmix per batch element</span>
  205. <span class="sd"> &quot;&quot;&quot;</span>
  206. <span class="n">lam</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
  207. <span class="n">use_cutmix</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">)</span>
  208. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mixup_enabled</span><span class="p">:</span>
  209. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mixup_alpha</span> <span class="o">&gt;</span> <span class="mf">0.</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">cutmix_alpha</span> <span class="o">&gt;</span> <span class="mf">0.</span><span class="p">:</span>
  210. <span class="n">use_cutmix</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">switch_prob</span>
  211. <span class="n">lam_mix</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">where</span><span class="p">(</span>
  212. <span class="n">use_cutmix</span><span class="p">,</span>
  213. <span class="n">torch</span><span class="o">.</span><span class="n">distributions</span><span class="o">.</span><span class="n">beta</span><span class="o">.</span><span class="n">Beta</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cutmix_alpha</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">cutmix_alpha</span><span class="p">)</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">sample_shape</span><span class="o">=</span><span class="n">batch_size</span><span class="p">),</span>
  214. <span class="n">torch</span><span class="o">.</span><span class="n">distributions</span><span class="o">.</span><span class="n">beta</span><span class="o">.</span><span class="n">Beta</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mixup_alpha</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mixup_alpha</span><span class="p">)</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">sample_shape</span><span class="o">=</span><span class="n">batch_size</span><span class="p">))</span>
  215. <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">mixup_alpha</span> <span class="o">&gt;</span> <span class="mf">0.</span><span class="p">:</span>
  216. <span class="n">lam_mix</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">distributions</span><span class="o">.</span><span class="n">beta</span><span class="o">.</span><span class="n">Beta</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mixup_alpha</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mixup_alpha</span><span class="p">)</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">sample_shape</span><span class="o">=</span><span class="n">batch_size</span><span class="p">)</span>
  217. <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">cutmix_alpha</span> <span class="o">&gt;</span> <span class="mf">0.</span><span class="p">:</span>
  218. <span class="n">use_cutmix</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">)</span>
  219. <span class="n">lam_mix</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">distributions</span><span class="o">.</span><span class="n">beta</span><span class="o">.</span><span class="n">Beta</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cutmix_alpha</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">cutmix_alpha</span><span class="p">)</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">sample_shape</span><span class="o">=</span><span class="n">batch_size</span><span class="p">)</span>
  220. <span class="k">else</span><span class="p">:</span>
  221. <span class="k">raise</span> <span class="n">IllegalDatasetParameterException</span><span class="p">(</span><span class="s2">&quot;One of mixup_alpha &gt; 0., cutmix_alpha &gt; 0., &quot;</span>
  222. <span class="s2">&quot;cutmix_minmax not None should be true.&quot;</span><span class="p">)</span>
  223. <span class="n">lam</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">mix_prob</span><span class="p">,</span> <span class="n">lam_mix</span><span class="o">.</span><span class="n">type</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span> <span class="n">lam</span><span class="p">)</span>
  224. <span class="k">return</span> <span class="n">lam</span><span class="p">,</span> <span class="n">use_cutmix</span>
  225. <span class="k">def</span> <span class="nf">_params_per_batch</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  226. <span class="sd">&quot;&quot;&quot;</span>
  227. <span class="sd"> generate two random parameters to define if batch will be mixed and how (depending on the</span>
  228. <span class="sd"> self.mixup_enabled, self.mixup_alpha, self.cutmix_alpha parameters</span>
  229. <span class="sd"> :return: two parameters - the first contains the lambda value for the whole batch</span>
  230. <span class="sd"> and the second is a binary flag indicating use of cutmix for the batch</span>
  231. <span class="sd"> &quot;&quot;&quot;</span>
  232. <span class="n">lam</span> <span class="o">=</span> <span class="mf">1.</span>
  233. <span class="n">use_cutmix</span> <span class="o">=</span> <span class="kc">False</span>
  234. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mixup_enabled</span> <span class="ow">and</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">mix_prob</span><span class="p">:</span>
  235. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mixup_alpha</span> <span class="o">&gt;</span> <span class="mf">0.</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">cutmix_alpha</span> <span class="o">&gt;</span> <span class="mf">0.</span><span class="p">:</span>
  236. <span class="n">use_cutmix</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">switch_prob</span>
  237. <span class="n">lam_mix</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">distributions</span><span class="o">.</span><span class="n">beta</span><span class="o">.</span><span class="n">Beta</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cutmix_alpha</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">cutmix_alpha</span><span class="p">)</span><span class="o">.</span><span class="n">sample</span><span class="p">()</span> <span class="k">if</span> <span class="n">use_cutmix</span> <span class="k">else</span> \
  238. <span class="n">torch</span><span class="o">.</span><span class="n">distributions</span><span class="o">.</span><span class="n">beta</span><span class="o">.</span><span class="n">Beta</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mixup_alpha</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mixup_alpha</span><span class="p">)</span><span class="o">.</span><span class="n">sample</span><span class="p">()</span>
  239. <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">mixup_alpha</span> <span class="o">&gt;</span> <span class="mf">0.</span><span class="p">:</span>
  240. <span class="n">lam_mix</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">distributions</span><span class="o">.</span><span class="n">beta</span><span class="o">.</span><span class="n">Beta</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mixup_alpha</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mixup_alpha</span><span class="p">)</span><span class="o">.</span><span class="n">sample</span><span class="p">()</span>
  241. <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">cutmix_alpha</span> <span class="o">&gt;</span> <span class="mf">0.</span><span class="p">:</span>
  242. <span class="n">use_cutmix</span> <span class="o">=</span> <span class="kc">True</span>
  243. <span class="n">lam_mix</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">distributions</span><span class="o">.</span><span class="n">beta</span><span class="o">.</span><span class="n">Beta</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cutmix_alpha</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">cutmix_alpha</span><span class="p">)</span><span class="o">.</span><span class="n">sample</span><span class="p">()</span>
  244. <span class="k">else</span><span class="p">:</span>
  245. <span class="k">raise</span> <span class="n">IllegalDatasetParameterException</span><span class="p">(</span><span class="s2">&quot;One of mixup_alpha &gt; 0., cutmix_alpha &gt; 0., &quot;</span>
  246. <span class="s2">&quot;cutmix_minmax not None should be true.&quot;</span><span class="p">)</span>
  247. <span class="n">lam</span> <span class="o">=</span> <span class="nb">float</span><span class="p">(</span><span class="n">lam_mix</span><span class="p">)</span>
  248. <span class="k">return</span> <span class="n">lam</span><span class="p">,</span> <span class="n">use_cutmix</span>
  249. <span class="k">def</span> <span class="nf">_mix_elem_collate</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span> <span class="n">half</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
  250. <span class="sd">&quot;&quot;&quot;</span>
  251. <span class="sd"> This is the implementation for &#39;elem&#39; or &#39;half&#39; modes</span>
  252. <span class="sd"> :param output: the output tensor to fill</span>
  253. <span class="sd"> :param batch: list of thr batch items</span>
  254. <span class="sd"> :return: a tensor containing the lambda values used for the mixing (this vector can be used for</span>
  255. <span class="sd"> mixing the labels as well)</span>
  256. <span class="sd"> &quot;&quot;&quot;</span>
  257. <span class="n">batch_size</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">batch</span><span class="p">)</span>
  258. <span class="n">num_elem</span> <span class="o">=</span> <span class="n">batch_size</span> <span class="o">//</span> <span class="mi">2</span> <span class="k">if</span> <span class="n">half</span> <span class="k">else</span> <span class="n">batch_size</span>
  259. <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">output</span><span class="p">)</span> <span class="o">==</span> <span class="n">num_elem</span>
  260. <span class="n">lam_batch</span><span class="p">,</span> <span class="n">use_cutmix</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_params_per_elem</span><span class="p">(</span><span class="n">num_elem</span><span class="p">)</span>
  261. <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_elem</span><span class="p">):</span>
  262. <span class="n">j</span> <span class="o">=</span> <span class="n">batch_size</span> <span class="o">-</span> <span class="n">i</span> <span class="o">-</span> <span class="mi">1</span>
  263. <span class="n">lam</span> <span class="o">=</span> <span class="n">lam_batch</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
  264. <span class="n">mixed</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span>
  265. <span class="k">if</span> <span class="n">lam</span> <span class="o">!=</span> <span class="mf">1.</span><span class="p">:</span>
  266. <span class="k">if</span> <span class="n">use_cutmix</span><span class="p">[</span><span class="n">i</span><span class="p">]:</span>
  267. <span class="k">if</span> <span class="ow">not</span> <span class="n">half</span><span class="p">:</span>
  268. <span class="n">mixed</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">clone</span><span class="p">(</span><span class="n">mixed</span><span class="p">)</span>
  269. <span class="p">(</span><span class="n">yl</span><span class="p">,</span> <span class="n">yh</span><span class="p">,</span> <span class="n">xl</span><span class="p">,</span> <span class="n">xh</span><span class="p">),</span> <span class="n">lam</span> <span class="o">=</span> <span class="n">cutmix_bbox_and_lam</span><span class="p">(</span>
  270. <span class="n">output</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">lam</span><span class="p">,</span> <span class="n">ratio_minmax</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">cutmix_minmax</span><span class="p">,</span> <span class="n">correct_lam</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">correct_lam</span><span class="p">)</span>
  271. <span class="n">mixed</span><span class="p">[:,</span> <span class="n">yl</span><span class="p">:</span><span class="n">yh</span><span class="p">,</span> <span class="n">xl</span><span class="p">:</span><span class="n">xh</span><span class="p">]</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="n">j</span><span class="p">][</span><span class="mi">0</span><span class="p">][:,</span> <span class="n">yl</span><span class="p">:</span><span class="n">yh</span><span class="p">,</span> <span class="n">xl</span><span class="p">:</span><span class="n">xh</span><span class="p">]</span>
  272. <span class="n">lam_batch</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">lam</span>
  273. <span class="k">else</span><span class="p">:</span>
  274. <span class="n">mixed</span> <span class="o">=</span> <span class="n">mixed</span> <span class="o">*</span> <span class="n">lam</span> <span class="o">+</span> <span class="n">batch</span><span class="p">[</span><span class="n">j</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">lam</span><span class="p">)</span>
  275. <span class="n">output</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">+=</span> <span class="n">mixed</span>
  276. <span class="k">if</span> <span class="n">half</span><span class="p">:</span>
  277. <span class="n">lam_batch</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">lam_batch</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">num_elem</span><span class="p">)))</span>
  278. <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">lam_batch</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
  279. <span class="k">def</span> <span class="nf">_mix_pair_collate</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="nb">list</span><span class="p">):</span>
  280. <span class="sd">&quot;&quot;&quot;</span>
  281. <span class="sd"> This is the implementation for &#39;pair&#39; mode</span>
  282. <span class="sd"> :param output: the output tensor to fill</span>
  283. <span class="sd"> :param batch: list of thr batch items</span>
  284. <span class="sd"> :return: a tensor containing the lambda values used for the mixing (this vector can be used for</span>
  285. <span class="sd"> mixing the labels as well)</span>
  286. <span class="sd"> &quot;&quot;&quot;</span>
  287. <span class="n">batch_size</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">batch</span><span class="p">)</span>
  288. <span class="n">lam_batch</span><span class="p">,</span> <span class="n">use_cutmix</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_params_per_elem</span><span class="p">(</span><span class="n">batch_size</span> <span class="o">//</span> <span class="mi">2</span><span class="p">)</span>
  289. <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span> <span class="o">//</span> <span class="mi">2</span><span class="p">):</span>
  290. <span class="n">j</span> <span class="o">=</span> <span class="n">batch_size</span> <span class="o">-</span> <span class="n">i</span> <span class="o">-</span> <span class="mi">1</span>
  291. <span class="n">lam</span> <span class="o">=</span> <span class="n">lam_batch</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
  292. <span class="n">mixed_i</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span>
  293. <span class="n">mixed_j</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="n">j</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span>
  294. <span class="k">assert</span> <span class="mi">0</span> <span class="o">&lt;=</span> <span class="n">lam</span> <span class="o">&lt;=</span> <span class="mf">1.0</span>
  295. <span class="k">if</span> <span class="n">lam</span> <span class="o">&lt;</span> <span class="mf">1.</span><span class="p">:</span>
  296. <span class="k">if</span> <span class="n">use_cutmix</span><span class="p">[</span><span class="n">i</span><span class="p">]:</span>
  297. <span class="p">(</span><span class="n">yl</span><span class="p">,</span> <span class="n">yh</span><span class="p">,</span> <span class="n">xl</span><span class="p">,</span> <span class="n">xh</span><span class="p">),</span> <span class="n">lam</span> <span class="o">=</span> <span class="n">cutmix_bbox_and_lam</span><span class="p">(</span>
  298. <span class="n">output</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">lam</span><span class="p">,</span> <span class="n">ratio_minmax</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">cutmix_minmax</span><span class="p">,</span> <span class="n">correct_lam</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">correct_lam</span><span class="p">)</span>
  299. <span class="n">patch_i</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">clone</span><span class="p">(</span><span class="n">mixed_i</span><span class="p">[:,</span> <span class="n">yl</span><span class="p">:</span><span class="n">yh</span><span class="p">,</span> <span class="n">xl</span><span class="p">:</span><span class="n">xh</span><span class="p">])</span>
  300. <span class="n">mixed_i</span><span class="p">[:,</span> <span class="n">yl</span><span class="p">:</span><span class="n">yh</span><span class="p">,</span> <span class="n">xl</span><span class="p">:</span><span class="n">xh</span><span class="p">]</span> <span class="o">=</span> <span class="n">mixed_j</span><span class="p">[:,</span> <span class="n">yl</span><span class="p">:</span><span class="n">yh</span><span class="p">,</span> <span class="n">xl</span><span class="p">:</span><span class="n">xh</span><span class="p">]</span>
  301. <span class="n">mixed_j</span><span class="p">[:,</span> <span class="n">yl</span><span class="p">:</span><span class="n">yh</span><span class="p">,</span> <span class="n">xl</span><span class="p">:</span><span class="n">xh</span><span class="p">]</span> <span class="o">=</span> <span class="n">patch_i</span>
  302. <span class="n">lam_batch</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">lam</span>
  303. <span class="k">else</span><span class="p">:</span>
  304. <span class="n">mixed_temp</span> <span class="o">=</span> <span class="n">mixed_i</span><span class="o">.</span><span class="n">type</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="o">*</span> <span class="n">lam</span> <span class="o">+</span> <span class="n">mixed_j</span><span class="o">.</span><span class="n">type</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">lam</span><span class="p">)</span>
  305. <span class="n">mixed_j</span> <span class="o">=</span> <span class="n">mixed_j</span><span class="o">.</span><span class="n">type</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="o">*</span> <span class="n">lam</span> <span class="o">+</span> <span class="n">mixed_i</span><span class="o">.</span><span class="n">type</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">lam</span><span class="p">)</span>
  306. <span class="n">mixed_i</span> <span class="o">=</span> <span class="n">mixed_temp</span>
  307. <span class="n">torch</span><span class="o">.</span><span class="n">rint</span><span class="p">(</span><span class="n">mixed_j</span><span class="p">,</span> <span class="n">out</span><span class="o">=</span><span class="n">mixed_j</span><span class="p">)</span>
  308. <span class="n">torch</span><span class="o">.</span><span class="n">rint</span><span class="p">(</span><span class="n">mixed_i</span><span class="p">,</span> <span class="n">out</span><span class="o">=</span><span class="n">mixed_i</span><span class="p">)</span>
  309. <span class="n">output</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">+=</span> <span class="n">mixed_i</span>
  310. <span class="n">output</span><span class="p">[</span><span class="n">j</span><span class="p">]</span> <span class="o">+=</span> <span class="n">mixed_j</span>
  311. <span class="n">lam_batch</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">lam_batch</span><span class="p">,</span> <span class="n">lam_batch</span><span class="p">[::</span><span class="o">-</span><span class="mi">1</span><span class="p">]))</span>
  312. <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">lam_batch</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
  313. <span class="k">def</span> <span class="nf">_mix_batch_collate</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="nb">list</span><span class="p">):</span>
  314. <span class="sd">&quot;&quot;&quot;</span>
  315. <span class="sd"> This is the implementation for &#39;batch&#39; mode</span>
  316. <span class="sd"> :param output: the output tensor to fill</span>
  317. <span class="sd"> :param batch: list of thr batch items</span>
  318. <span class="sd"> :return: the lambda value used for the mixing</span>
  319. <span class="sd"> &quot;&quot;&quot;</span>
  320. <span class="n">batch_size</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">batch</span><span class="p">)</span>
  321. <span class="n">lam</span><span class="p">,</span> <span class="n">use_cutmix</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_params_per_batch</span><span class="p">()</span>
  322. <span class="k">if</span> <span class="n">use_cutmix</span><span class="p">:</span>
  323. <span class="p">(</span><span class="n">yl</span><span class="p">,</span> <span class="n">yh</span><span class="p">,</span> <span class="n">xl</span><span class="p">,</span> <span class="n">xh</span><span class="p">),</span> <span class="n">lam</span> <span class="o">=</span> <span class="n">cutmix_bbox_and_lam</span><span class="p">(</span>
  324. <span class="n">output</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">lam</span><span class="p">,</span> <span class="n">ratio_minmax</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">cutmix_minmax</span><span class="p">,</span> <span class="n">correct_lam</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">correct_lam</span><span class="p">)</span>
  325. <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span>
  326. <span class="n">j</span> <span class="o">=</span> <span class="n">batch_size</span> <span class="o">-</span> <span class="n">i</span> <span class="o">-</span> <span class="mi">1</span>
  327. <span class="n">mixed</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span>
  328. <span class="k">if</span> <span class="n">lam</span> <span class="o">!=</span> <span class="mf">1.</span><span class="p">:</span>
  329. <span class="k">if</span> <span class="n">use_cutmix</span><span class="p">:</span>
  330. <span class="n">mixed</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">clone</span><span class="p">(</span><span class="n">mixed</span><span class="p">)</span> <span class="c1"># don&#39;t want to modify the original while iterating</span>
  331. <span class="n">mixed</span><span class="p">[:,</span> <span class="n">yl</span><span class="p">:</span><span class="n">yh</span><span class="p">,</span> <span class="n">xl</span><span class="p">:</span><span class="n">xh</span><span class="p">]</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="n">j</span><span class="p">][</span><span class="mi">0</span><span class="p">][:,</span> <span class="n">yl</span><span class="p">:</span><span class="n">yh</span><span class="p">,</span> <span class="n">xl</span><span class="p">:</span><span class="n">xh</span><span class="p">]</span>
  332. <span class="k">else</span><span class="p">:</span>
  333. <span class="n">mixed</span> <span class="o">=</span> <span class="n">mixed</span> <span class="o">*</span> <span class="n">lam</span> <span class="o">+</span> <span class="n">batch</span><span class="p">[</span><span class="n">j</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">lam</span><span class="p">)</span>
  334. <span class="n">output</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">+=</span> <span class="n">mixed</span>
  335. <span class="k">return</span> <span class="n">lam</span>
  336. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="n">_</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  337. <span class="n">batch_size</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">batch</span><span class="p">)</span>
  338. <span class="k">if</span> <span class="n">batch_size</span> <span class="o">%</span> <span class="mi">2</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
  339. <span class="k">raise</span> <span class="n">IllegalDatasetParameterException</span><span class="p">(</span><span class="s1">&#39;Batch size should be even when using this&#39;</span><span class="p">)</span>
  340. <span class="n">half</span> <span class="o">=</span> <span class="s1">&#39;half&#39;</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span>
  341. <span class="k">if</span> <span class="n">half</span><span class="p">:</span>
  342. <span class="n">batch_size</span> <span class="o">//=</span> <span class="mi">2</span>
  343. <span class="n">output</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">*</span><span class="n">batch</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
  344. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span> <span class="o">==</span> <span class="s1">&#39;elem&#39;</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span> <span class="o">==</span> <span class="s1">&#39;half&#39;</span><span class="p">:</span>
  345. <span class="n">lam</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_mix_elem_collate</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="n">half</span><span class="o">=</span><span class="n">half</span><span class="p">)</span>
  346. <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span> <span class="o">==</span> <span class="s1">&#39;pair&#39;</span><span class="p">:</span>
  347. <span class="n">lam</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_mix_pair_collate</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span>
  348. <span class="k">else</span><span class="p">:</span>
  349. <span class="n">lam</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_mix_batch_collate</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span>
  350. <span class="n">target</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="n">b</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="n">batch</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
  351. <span class="n">target</span> <span class="o">=</span> <span class="n">mixup_target</span><span class="p">(</span><span class="n">target</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_classes</span><span class="p">,</span> <span class="n">lam</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">label_smoothing</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cpu&#39;</span><span class="p">)</span>
  352. <span class="n">target</span> <span class="o">=</span> <span class="n">target</span><span class="p">[:</span><span class="n">batch_size</span><span class="p">]</span>
  353. <span class="k">return</span> <span class="n">output</span><span class="p">,</span> <span class="n">target</span></div>
  354. </pre></div>
  355. </div>
  356. </div>
  357. <footer>
  358. <hr/>
  359. <div role="contentinfo">
  360. <p>&#169; Copyright 2021, SuperGradients team.</p>
  361. </div>
  362. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  363. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  364. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  365. </footer>
  366. </div>
  367. </div>
  368. </section>
  369. </div>
  370. <script>
  371. jQuery(function () {
  372. SphinxRtdTheme.Navigation.enable(true);
  373. });
  374. </script>
  375. </body>
  376. </html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.training.datasets.segmentation_datasets.cityscape_segmentation &mdash; SuperGradients 1.0 documentation</title>
  7. <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
  10. <!--[if lt IE 9]>
  11. <script src="../../../../../_static/js/html5shiv.min.js"></script>
  12. <![endif]-->
  13. <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
  14. <script src="../../../../../_static/jquery.js"></script>
  15. <script src="../../../../../_static/underscore.js"></script>
  16. <script src="../../../../../_static/doctools.js"></script>
  17. <script src="../../../../../_static/js/theme.js"></script>
  18. <link rel="index" title="Index" href="../../../../../genindex.html" />
  19. <link rel="search" title="Search" href="../../../../../search.html" />
  20. </head>
  21. <body class="wy-body-for-nav">
  22. <div class="wy-grid-for-nav">
  23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  24. <div class="wy-side-scroll">
  25. <div class="wy-side-nav-search" >
  26. <a href="../../../../../index.html" class="icon icon-home"> SuperGradients
  27. </a>
  28. <div role="search">
  29. <form id="rtd-search-form" class="wy-form" action="../../../../../search.html" method="get">
  30. <input type="text" name="q" placeholder="Search docs" />
  31. <input type="hidden" name="check_keywords" value="yes" />
  32. <input type="hidden" name="area" value="default" />
  33. </form>
  34. </div>
  35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
  37. <ul>
  38. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
  39. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#supergradients">SuperGradients</a></li>
  40. </ul>
  41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
  42. <ul>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
  45. </ul>
  46. <p class="caption"><span class="caption-text">User Guide</span></p>
  47. <ul>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html">What is SuperGradients?</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#installation">Installation</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#training-parameters">Training Parameters</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#network-architectures">Network Architectures</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
  57. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
  58. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
  59. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
  60. </ul>
  61. </div>
  62. </div>
  63. </nav>
  64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  66. <a href="../../../../../index.html">SuperGradients</a>
  67. </nav>
  68. <div class="wy-nav-content">
  69. <div class="rst-content">
  70. <div role="navigation" aria-label="Page navigation">
  71. <ul class="wy-breadcrumbs">
  72. <li><a href="../../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  73. <li><a href="../../../../index.html">Module code</a> &raquo;</li>
  74. <li>super_gradients.training.datasets.segmentation_datasets.cityscape_segmentation</li>
  75. <li class="wy-breadcrumbs-aside">
  76. </li>
  77. </ul>
  78. <hr/>
  79. </div>
  80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  81. <div itemprop="articleBody">
  82. <h1>Source code for super_gradients.training.datasets.segmentation_datasets.cityscape_segmentation</h1><div class="highlight"><pre>
  83. <span></span><span class="kn">import</span> <span class="nn">os</span>
  84. <span class="kn">import</span> <span class="nn">cv2</span>
  85. <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
  86. <span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span><span class="p">,</span> <span class="n">ImageColor</span>
  87. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.segmentation_datasets.segmentation_dataset</span> <span class="kn">import</span> <span class="n">SegmentationDataSet</span>
  88. <span class="c1"># TODO - ADD COARSE DATA - right now cityscapes dataset includes fine annotations. It&#39;s optional to use extra coarse</span>
  89. <span class="c1"># annotations.</span>
  90. <span class="c1"># label for background and labels to ignore during training and evaluation.</span>
  91. <span class="n">CITYSCAPES_IGNORE_LABEL</span> <span class="o">=</span> <span class="mi">19</span>
  92. <div class="viewcode-block" id="CityscapesDataset"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.segmentation_datasets.html#super_gradients.training.datasets.segmentation_datasets.cityscape_segmentation.CityscapesDataset">[docs]</a><span class="k">class</span> <span class="nc">CityscapesDataset</span><span class="p">(</span><span class="n">SegmentationDataSet</span><span class="p">):</span>
  93. <span class="sd">&quot;&quot;&quot;</span>
  94. <span class="sd"> CityscapesDataset - Segmentation Data Set Class for Cityscapes Segmentation Data Set,</span>
  95. <span class="sd"> main resolution of dataset: (2048 x 1024).</span>
  96. <span class="sd"> Not all the original labels are used for training and evaluation, according to cityscape paper:</span>
  97. <span class="sd"> &quot;Classes that are too rare are excluded from our benchmark, leaving 19 classes for evaluation&quot;.</span>
  98. <span class="sd"> For more details about the dataset labels format see:</span>
  99. <span class="sd"> https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py</span>
  100. <span class="sd"> &quot;&quot;&quot;</span>
  101. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
  102. <span class="n">root_dir</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
  103. <span class="n">list_file</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
  104. <span class="n">labels_csv_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
  105. <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  106. <span class="sd">&quot;&quot;&quot;</span>
  107. <span class="sd"> :param root: root directory to dataset.</span>
  108. <span class="sd"> :param list_file: list file that contains names of images to load, line format: &lt;image_path&gt; &lt;label_path&gt;</span>
  109. <span class="sd"> :param labels_csv_path: path to csv file, with labels metadata and mapping.</span>
  110. <span class="sd"> :param kwargs: Any hyper params required for the dataset, i.e img_size, crop_size, cache_images</span>
  111. <span class="sd"> &quot;&quot;&quot;</span>
  112. <span class="bp">self</span><span class="o">.</span><span class="n">root_dir</span> <span class="o">=</span> <span class="n">root_dir</span>
  113. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">root_dir</span><span class="p">,</span> <span class="n">list_file</span><span class="o">=</span><span class="n">list_file</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
  114. <span class="c1"># labels dataframe for labels metadata.</span>
  115. <span class="bp">self</span><span class="o">.</span><span class="n">labels_data</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">recfromcsv</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">root_dir</span><span class="p">,</span> <span class="n">labels_csv_path</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="s1">&#39;&lt;i8,U20,&lt;i8,&lt;i8,U12,&lt;i8,?,?,U7&#39;</span><span class="p">,</span> <span class="n">comments</span><span class="o">=</span><span class="s1">&#39;&amp;&#39;</span><span class="p">)</span>
  116. <span class="c1"># map vector to map ground-truth labels to train labels</span>
  117. <span class="bp">self</span><span class="o">.</span><span class="n">labels_map</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">labels_data</span><span class="o">.</span><span class="n">field</span><span class="p">(</span><span class="s2">&quot;trainid&quot;</span><span class="p">)</span>
  118. <span class="c1"># class names</span>
  119. <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">labels_data</span><span class="o">.</span><span class="n">field</span><span class="p">(</span><span class="s2">&quot;name&quot;</span><span class="p">)[</span><span class="n">np</span><span class="o">.</span><span class="n">logical_not</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">labels_data</span><span class="o">.</span><span class="n">field</span><span class="p">(</span><span class="s2">&quot;ignoreineval&quot;</span><span class="p">))]</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span>
  120. <span class="c1"># color palette for visualization</span>
  121. <span class="bp">self</span><span class="o">.</span><span class="n">train_id_color_palette</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_create_color_palette</span><span class="p">()</span>
  122. <span class="k">def</span> <span class="nf">_generate_samples_and_targets</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  123. <span class="sd">&quot;&quot;&quot;</span>
  124. <span class="sd"> override _generate_samples_and_targets function, to parse list file.</span>
  125. <span class="sd"> line format of list file: &lt;image_path&gt; &lt;label_path&gt;</span>
  126. <span class="sd"> &quot;&quot;&quot;</span>
  127. <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">root_dir</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">list_file_path</span><span class="p">))</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
  128. <span class="n">img_list</span> <span class="o">=</span> <span class="p">[</span><span class="n">line</span><span class="o">.</span><span class="n">strip</span><span class="p">()</span><span class="o">.</span><span class="n">split</span><span class="p">()</span> <span class="k">for</span> <span class="n">line</span> <span class="ow">in</span> <span class="n">f</span><span class="p">]</span>
  129. <span class="k">for</span> <span class="n">image_path</span><span class="p">,</span> <span class="n">label_path</span> <span class="ow">in</span> <span class="n">img_list</span><span class="p">:</span>
  130. <span class="bp">self</span><span class="o">.</span><span class="n">samples_targets_tuples_list</span><span class="o">.</span><span class="n">append</span><span class="p">((</span>
  131. <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">root</span><span class="p">,</span> <span class="n">image_path</span><span class="p">),</span>
  132. <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">root</span><span class="p">,</span> <span class="n">label_path</span><span class="p">)</span>
  133. <span class="p">))</span>
  134. <div class="viewcode-block" id="CityscapesDataset.target_loader"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.segmentation_datasets.html#super_gradients.training.datasets.segmentation_datasets.cityscape_segmentation.CityscapesDataset.target_loader">[docs]</a> <span class="k">def</span> <span class="nf">target_loader</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">label_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Image</span><span class="p">:</span>
  135. <span class="sd">&quot;&quot;&quot;</span>
  136. <span class="sd"> Override target_loader function, load the labels mask image.</span>
  137. <span class="sd"> :param label_path: Path to the label image.</span>
  138. <span class="sd"> :return: The mask image created from the array, with converted class labels.</span>
  139. <span class="sd"> &quot;&quot;&quot;</span>
  140. <span class="c1"># assert that is a png file, other file types might alter the class labels value.</span>
  141. <span class="k">assert</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">splitext</span><span class="p">(</span><span class="n">label_path</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">lower</span><span class="p">()</span> <span class="o">==</span> <span class="s2">&quot;.png&quot;</span>
  142. <span class="n">label</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">imread</span><span class="p">(</span><span class="n">label_path</span><span class="p">,</span> <span class="n">cv2</span><span class="o">.</span><span class="n">IMREAD_GRAYSCALE</span><span class="p">)</span>
  143. <span class="c1"># map ground-truth ids to train ids</span>
  144. <span class="n">label</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">labels_map</span><span class="p">[</span><span class="n">label</span><span class="p">]</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span>
  145. <span class="k">return</span> <span class="n">Image</span><span class="o">.</span><span class="n">fromarray</span><span class="p">(</span><span class="n">label</span><span class="p">,</span> <span class="s1">&#39;L&#39;</span><span class="p">)</span></div>
  146. <span class="k">def</span> <span class="nf">_create_color_palette</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  147. <span class="sd">&quot;&quot;&quot;</span>
  148. <span class="sd"> Create color pallete for visualizing the segmentation masks</span>
  149. <span class="sd"> :return: list of rgb color values</span>
  150. <span class="sd"> &quot;&quot;&quot;</span>
  151. <span class="n">palette</span> <span class="o">=</span> <span class="p">[]</span>
  152. <span class="n">hex_colors</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">labels_data</span><span class="o">.</span><span class="n">field</span><span class="p">(</span><span class="s2">&quot;color&quot;</span><span class="p">)[</span><span class="n">np</span><span class="o">.</span><span class="n">logical_not</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">labels_data</span><span class="o">.</span><span class="n">field</span><span class="p">(</span><span class="s2">&quot;ignoreineval&quot;</span><span class="p">))]</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span>
  153. <span class="k">for</span> <span class="n">hex_color</span> <span class="ow">in</span> <span class="n">hex_colors</span><span class="p">:</span>
  154. <span class="n">rgb_color</span> <span class="o">=</span> <span class="n">ImageColor</span><span class="o">.</span><span class="n">getcolor</span><span class="p">(</span><span class="n">hex_color</span><span class="p">,</span> <span class="s2">&quot;RGB&quot;</span><span class="p">)</span>
  155. <span class="n">palette</span> <span class="o">+=</span> <span class="p">[</span><span class="n">x</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">rgb_color</span><span class="p">]</span>
  156. <span class="k">return</span> <span class="n">palette</span>
  157. <div class="viewcode-block" id="CityscapesDataset.get_train_ids_color_palette"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.segmentation_datasets.html#super_gradients.training.datasets.segmentation_datasets.cityscape_segmentation.CityscapesDataset.get_train_ids_color_palette">[docs]</a> <span class="k">def</span> <span class="nf">get_train_ids_color_palette</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  158. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_id_color_palette</span></div>
  159. <div class="viewcode-block" id="CityscapesDataset.target_transform"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.segmentation_datasets.html#super_gradients.training.datasets.segmentation_datasets.cityscape_segmentation.CityscapesDataset.target_transform">[docs]</a> <span class="nd">@staticmethod</span>
  160. <span class="k">def</span> <span class="nf">target_transform</span><span class="p">(</span><span class="n">target</span><span class="p">):</span>
  161. <span class="sd">&quot;&quot;&quot;</span>
  162. <span class="sd"> target_transform - Transforms the sample image</span>
  163. <span class="sd"> This function overrides the original function from SegmentationDataSet and changes target pixels with value</span>
  164. <span class="sd"> 255 to value = CITYSCAPES_IGNORE_LABEL. This was done since current IoU metric from torchmetrics does not</span>
  165. <span class="sd"> support such a high ignore label value (crashed on OOM)</span>
  166. <span class="sd"> :param target: The target mask to transform</span>
  167. <span class="sd"> :return: The transformed target mask</span>
  168. <span class="sd"> &quot;&quot;&quot;</span>
  169. <span class="n">out</span> <span class="o">=</span> <span class="n">SegmentationDataSet</span><span class="o">.</span><span class="n">target_transform</span><span class="p">(</span><span class="n">target</span><span class="p">)</span>
  170. <span class="n">out</span><span class="p">[</span><span class="n">out</span> <span class="o">==</span> <span class="mi">255</span><span class="p">]</span> <span class="o">=</span> <span class="n">CITYSCAPES_IGNORE_LABEL</span>
  171. <span class="k">return</span> <span class="n">out</span></div></div>
  172. </pre></div>
  173. </div>
  174. </div>
  175. <footer>
  176. <hr/>
  177. <div role="contentinfo">
  178. <p>&#169; Copyright 2021, SuperGradients team.</p>
  179. </div>
  180. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  181. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  182. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  183. </footer>
  184. </div>
  185. </div>
  186. </section>
  187. </div>
  188. <script>
  189. jQuery(function () {
  190. SphinxRtdTheme.Navigation.enable(true);
  191. });
  192. </script>
  193. </body>
  194. </html>
Discard
Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
@@ -3,10 +3,11 @@
 <head>
 <head>
   <meta charset="utf-8" />
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-  <title>super_gradients.training.datasets.segmentation_datasets.coco_segmentation &mdash; SuperGradients 1.0 documentation</title>
+  <title>super_gradients.training.datasets.segmentation_datasets.coco_segmentation &mdash; SuperGradients 3.0.3 documentation</title>
       <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
+      <link rel="stylesheet" href="../../../../../_static/custom.css" type="text/css" />
   <!--[if lt IE 9]>
   <!--[if lt IE 9]>
     <script src="../../../../../_static/js/html5shiv.min.js"></script>
     <script src="../../../../../_static/js/html5shiv.min.js"></script>
   <![endif]-->
   <![endif]-->
@@ -14,7 +15,9 @@
         <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
         <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
         <script src="../../../../../_static/jquery.js"></script>
         <script src="../../../../../_static/jquery.js"></script>
         <script src="../../../../../_static/underscore.js"></script>
         <script src="../../../../../_static/underscore.js"></script>
+        <script src="../../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
         <script src="../../../../../_static/doctools.js"></script>
         <script src="../../../../../_static/doctools.js"></script>
+        <script src="../../../../../_static/sphinx_highlight.js"></script>
     <script src="../../../../../_static/js/theme.js"></script>
     <script src="../../../../../_static/js/theme.js"></script>
     <link rel="index" title="Index" href="../../../../../genindex.html" />
     <link rel="index" title="Index" href="../../../../../genindex.html" />
     <link rel="search" title="Search" href="../../../../../search.html" /> 
     <link rel="search" title="Search" href="../../../../../search.html" /> 
@@ -35,30 +38,29 @@
   </form>
   </form>
 </div>
 </div>
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
-              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
+              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
 <ul>
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#supergradients">SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#quick-installation">Quick Installation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#what-s-new">What’s New</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#coming-soon">Coming soon</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#table-of-content">Table of Content</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#getting-started">Getting Started</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#advanced-features">Advanced Features</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#installation-methods">Installation Methods</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#documentation">Documentation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#contributing">Contributing</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#citation">Citation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#community">Community</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#license">License</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#deci-platform">Deci Platform</a></li>
 </ul>
 </ul>
-<p class="caption"><span class="caption-text">Technical Documentation</span></p>
+<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
 <ul>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
-</ul>
-<p class="caption"><span class="caption-text">User Guide</span></p>
-<ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html">What is SuperGradients?</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#installation">Installation</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#training-parameters">Training Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#network-architectures">Network Architectures</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
 </ul>
 </ul>
 
 
         </div>
         </div>
@@ -87,14 +89,11 @@
              
              
   <h1>Source code for super_gradients.training.datasets.segmentation_datasets.coco_segmentation</h1><div class="highlight"><pre>
   <h1>Source code for super_gradients.training.datasets.segmentation_datasets.coco_segmentation</h1><div class="highlight"><pre>
 <span></span><span class="kn">import</span> <span class="nn">os</span>
 <span></span><span class="kn">import</span> <span class="nn">os</span>
-<span class="kn">import</span> <span class="nn">torch</span>
+
 <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
 <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
-<span class="kn">from</span> <span class="nn">tqdm</span> <span class="kn">import</span> <span class="n">tqdm</span>
+<span class="kn">import</span> <span class="nn">torch</span>
 <span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span>
 <span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span>
-<span class="kn">import</span> <span class="nn">torchvision.transforms</span> <span class="k">as</span> <span class="nn">transform</span>
-<span class="kn">from</span> <span class="nn">super_gradients.training.transforms.transforms</span> <span class="kn">import</span> <span class="n">RandomFlip</span><span class="p">,</span> <span class="n">Rescale</span><span class="p">,</span> <span class="n">RandomRescale</span><span class="p">,</span> <span class="n">CropImageAndMask</span><span class="p">,</span> \
-    <span class="n">PadShortToCropSize</span>
-<span class="kn">from</span> <span class="nn">super_gradients.training.utils.utils</span> <span class="kn">import</span> <span class="n">get_param</span>
+<span class="kn">from</span> <span class="nn">tqdm</span> <span class="kn">import</span> <span class="n">tqdm</span>
 
 
 <span class="k">try</span><span class="p">:</span>
 <span class="k">try</span><span class="p">:</span>
     <span class="kn">from</span> <span class="nn">pycocotools.coco</span> <span class="kn">import</span> <span class="n">COCO</span>
     <span class="kn">from</span> <span class="nn">pycocotools.coco</span> <span class="kn">import</span> <span class="n">COCO</span>
@@ -106,36 +105,22 @@
 <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.segmentation_datasets.segmentation_dataset</span> <span class="kn">import</span> <span class="n">SegmentationDataSet</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.segmentation_datasets.segmentation_dataset</span> <span class="kn">import</span> <span class="n">SegmentationDataSet</span>
 
 
 
 
-<div class="viewcode-block" id="EmptyCoCoClassesSelectionException"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.segmentation_datasets.html#super_gradients.training.datasets.EmptyCoCoClassesSelectionException">[docs]</a><span class="k">class</span> <span class="nc">EmptyCoCoClassesSelectionException</span><span class="p">(</span><span class="ne">Exception</span><span class="p">):</span>
-    <span class="k">pass</span></div>
+<span class="k">class</span> <span class="nc">EmptyCoCoClassesSelectionException</span><span class="p">(</span><span class="ne">Exception</span><span class="p">):</span>
+    <span class="k">pass</span>
 
 
 
 
-<div class="viewcode-block" id="CoCoSegmentationDataSet"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.segmentation_datasets.html#super_gradients.training.datasets.CoCoSegmentationDataSet">[docs]</a><span class="k">class</span> <span class="nc">CoCoSegmentationDataSet</span><span class="p">(</span><span class="n">SegmentationDataSet</span><span class="p">):</span>
+<div class="viewcode-block" id="CoCoSegmentationDataSet"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.CoCoSegmentationDataSet">[docs]</a><span class="k">class</span> <span class="nc">CoCoSegmentationDataSet</span><span class="p">(</span><span class="n">SegmentationDataSet</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    CoCoSegmentationDataSet - Segmentation Data Set Class for COCO 2017 Segmentation Data Set</span>
 <span class="sd">    CoCoSegmentationDataSet - Segmentation Data Set Class for COCO 2017 Segmentation Data Set</span>
 <span class="sd">    &quot;&quot;&quot;</span>
 <span class="sd">    &quot;&quot;&quot;</span>
 
 
-    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_classes_inclusion_tuples_list</span><span class="p">:</span> <span class="nb">list</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">root_dir</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
+                 <span class="n">dataset_classes_inclusion_tuples_list</span><span class="p">:</span> <span class="nb">list</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
         <span class="c1"># THERE ARE 91 CLASSES, INCLUDING BACKGROUND - BUT WE ENABLE THE USAGE OF SUBCLASSES, TO PARTIALLY USE THE DATA</span>
         <span class="c1"># THERE ARE 91 CLASSES, INCLUDING BACKGROUND - BUT WE ENABLE THE USAGE OF SUBCLASSES, TO PARTIALLY USE THE DATA</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">dataset_classes_inclusion_tuples_list</span> <span class="o">=</span> <span class="n">dataset_classes_inclusion_tuples_list</span> <span class="ow">or</span> <span class="n">COCO_DEFAULT_CLASSES_TUPLES_LIST</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">dataset_classes_inclusion_tuples_list</span> <span class="o">=</span> <span class="n">dataset_classes_inclusion_tuples_list</span> <span class="ow">or</span> <span class="n">COCO_DEFAULT_CLASSES_TUPLES_LIST</span>
 
 
-        <span class="c1"># OVERRIDE DEFAULT AUGMENTATIONS, IMG_SIZE, CROP SIZE</span>
-        <span class="n">dataset_hyper_params</span> <span class="o">=</span> <span class="n">kwargs</span><span class="p">[</span><span class="s1">&#39;dataset_hyper_params&#39;</span><span class="p">]</span>
-        <span class="n">kwargs</span><span class="p">[</span><span class="s1">&#39;img_size&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">get_param</span><span class="p">(</span><span class="n">dataset_hyper_params</span><span class="p">,</span> <span class="s1">&#39;img_size&#39;</span><span class="p">,</span> <span class="mi">608</span><span class="p">)</span>
-        <span class="n">kwargs</span><span class="p">[</span><span class="s1">&#39;crop_size&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">get_param</span><span class="p">(</span><span class="n">dataset_hyper_params</span><span class="p">,</span> <span class="s1">&#39;crop_size&#39;</span><span class="p">,</span> <span class="mi">512</span><span class="p">)</span>
-        <span class="c1"># FIXME - Rescale before RandomRescale is kept for legacy support, consider removing it like most implementation</span>
-        <span class="c1">#  papers regimes.</span>
-        <span class="n">default_image_mask_transforms_aug</span> <span class="o">=</span> <span class="n">transform</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span><span class="n">RandomFlip</span><span class="p">(),</span>
-                                                               <span class="n">Rescale</span><span class="p">(</span><span class="n">long_size</span><span class="o">=</span><span class="n">kwargs</span><span class="p">[</span><span class="s2">&quot;img_size&quot;</span><span class="p">]),</span>
-                                                               <span class="n">RandomRescale</span><span class="p">(</span><span class="n">scales</span><span class="o">=</span><span class="p">(</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">)),</span>
-                                                               <span class="n">PadShortToCropSize</span><span class="p">(</span><span class="n">crop_size</span><span class="o">=</span><span class="n">kwargs</span><span class="p">[</span><span class="s1">&#39;crop_size&#39;</span><span class="p">]),</span>
-                                                               <span class="n">CropImageAndMask</span><span class="p">(</span><span class="n">crop_size</span><span class="o">=</span><span class="n">kwargs</span><span class="p">[</span><span class="s1">&#39;crop_size&#39;</span><span class="p">],</span> <span class="n">mode</span><span class="o">=</span><span class="s2">&quot;random&quot;</span><span class="p">)])</span>
-        <span class="n">kwargs</span><span class="p">[</span><span class="s2">&quot;image_mask_transforms_aug&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">get_param</span><span class="p">(</span><span class="n">dataset_hyper_params</span><span class="p">,</span> <span class="s2">&quot;image_mask_transforms_aug&quot;</span><span class="p">,</span>
-                                                        <span class="n">default_image_mask_transforms_aug</span><span class="p">)</span>
-        <span class="n">image_mask_transforms</span> <span class="o">=</span> <span class="n">get_param</span><span class="p">(</span><span class="n">dataset_hyper_params</span><span class="p">,</span> <span class="s1">&#39;image_mask_transforms&#39;</span><span class="p">)</span>
-        <span class="k">if</span> <span class="n">image_mask_transforms</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
-            <span class="n">kwargs</span><span class="p">[</span><span class="s2">&quot;image_mask_transforms&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">image_mask_transforms</span>
-        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">root_dir</span> <span class="o">=</span> <span class="n">root_dir</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">root_dir</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
 
 
         <span class="n">_</span><span class="p">,</span> <span class="n">class_names</span> <span class="o">=</span> <span class="nb">zip</span><span class="p">(</span><span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_classes_inclusion_tuples_list</span><span class="p">)</span>
         <span class="n">_</span><span class="p">,</span> <span class="n">class_names</span> <span class="o">=</span> <span class="nb">zip</span><span class="p">(</span><span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_classes_inclusion_tuples_list</span><span class="p">)</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="n">class_names</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="n">class_names</span>
@@ -162,7 +147,9 @@
             <span class="n">mask_metadata_tuple</span> <span class="o">=</span> <span class="p">(</span><span class="n">relevant_image_id</span><span class="p">,</span> <span class="n">img_metadata</span><span class="p">[</span><span class="s1">&#39;height&#39;</span><span class="p">],</span> <span class="n">img_metadata</span><span class="p">[</span><span class="s1">&#39;width&#39;</span><span class="p">])</span>
             <span class="n">mask_metadata_tuple</span> <span class="o">=</span> <span class="p">(</span><span class="n">relevant_image_id</span><span class="p">,</span> <span class="n">img_metadata</span><span class="p">[</span><span class="s1">&#39;height&#39;</span><span class="p">],</span> <span class="n">img_metadata</span><span class="p">[</span><span class="s1">&#39;width&#39;</span><span class="p">])</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">samples_targets_tuples_list</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">image_path</span><span class="p">,</span> <span class="n">mask_metadata_tuple</span><span class="p">))</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">samples_targets_tuples_list</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">image_path</span><span class="p">,</span> <span class="n">mask_metadata_tuple</span><span class="p">))</span>
 
 
-<div class="viewcode-block" id="CoCoSegmentationDataSet.target_loader"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.segmentation_datasets.html#super_gradients.training.datasets.CoCoSegmentationDataSet.target_loader">[docs]</a>    <span class="k">def</span> <span class="nf">target_loader</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mask_metadata_tuple</span><span class="p">)</span> <span class="o">-&gt;</spa
+        <span class="nb">super</span><span class="p">(</span><span class="n">CoCoSegmentationDataSet</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">_generate_samples_and_targets</span><span class="p">()</span>
+
+<div class="viewcode-block" id="CoCoSegmentationDataSet.target_loader"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.CoCoSegmentationDataSet.target_loader">[docs]</a>    <span class="k">def</span> <span class="nf">target_loader</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mask_metadata_tuple</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Image</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">        target_loader</span>
 <span class="sd">        target_loader</span>
 <span class="sd">            :param mask_metadata_tuple:  A tuple of (coco_image_id, original_image_height, original_image_width)</span>
 <span class="sd">            :param mask_metadata_tuple:  A tuple of (coco_image_id, original_image_height, original_image_width)</span>
@@ -267,4 +254,4 @@
   </script> 
   </script> 
 
 
 </body>
 </body>
-</html>
+</html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.training.datasets.segmentation_datasets.pascal_aug_segmentation &mdash; SuperGradients 1.0 documentation</title>
  7. <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
  10. <!--[if lt IE 9]>
  11. <script src="../../../../../_static/js/html5shiv.min.js"></script>
  12. <![endif]-->
  13. <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
  14. <script src="../../../../../_static/jquery.js"></script>
  15. <script src="../../../../../_static/underscore.js"></script>
  16. <script src="../../../../../_static/doctools.js"></script>
  17. <script src="../../../../../_static/js/theme.js"></script>
  18. <link rel="index" title="Index" href="../../../../../genindex.html" />
  19. <link rel="search" title="Search" href="../../../../../search.html" />
  20. </head>
  21. <body class="wy-body-for-nav">
  22. <div class="wy-grid-for-nav">
  23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  24. <div class="wy-side-scroll">
  25. <div class="wy-side-nav-search" >
  26. <a href="../../../../../index.html" class="icon icon-home"> SuperGradients
  27. </a>
  28. <div role="search">
  29. <form id="rtd-search-form" class="wy-form" action="../../../../../search.html" method="get">
  30. <input type="text" name="q" placeholder="Search docs" />
  31. <input type="hidden" name="check_keywords" value="yes" />
  32. <input type="hidden" name="area" value="default" />
  33. </form>
  34. </div>
  35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
  37. <ul>
  38. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
  39. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#supergradients">SuperGradients</a></li>
  40. </ul>
  41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
  42. <ul>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
  45. </ul>
  46. <p class="caption"><span class="caption-text">User Guide</span></p>
  47. <ul>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html">What is SuperGradients?</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#installation">Installation</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#training-parameters">Training Parameters</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#network-architectures">Network Architectures</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
  57. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
  58. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
  59. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
  60. </ul>
  61. </div>
  62. </div>
  63. </nav>
  64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  66. <a href="../../../../../index.html">SuperGradients</a>
  67. </nav>
  68. <div class="wy-nav-content">
  69. <div class="rst-content">
  70. <div role="navigation" aria-label="Page navigation">
  71. <ul class="wy-breadcrumbs">
  72. <li><a href="../../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  73. <li><a href="../../../../index.html">Module code</a> &raquo;</li>
  74. <li>super_gradients.training.datasets.segmentation_datasets.pascal_aug_segmentation</li>
  75. <li class="wy-breadcrumbs-aside">
  76. </li>
  77. </ul>
  78. <hr/>
  79. </div>
  80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  81. <div itemprop="articleBody">
  82. <h1>Source code for super_gradients.training.datasets.segmentation_datasets.pascal_aug_segmentation</h1><div class="highlight"><pre>
  83. <span></span><span class="kn">import</span> <span class="nn">scipy.io</span>
  84. <span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span>
  85. <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.segmentation_datasets.pascal_voc_segmentation</span> <span class="kn">import</span> <span class="n">PascalVOC2012SegmentationDataSet</span>
  86. <span class="n">PASCAL_AUG_CLASSES</span> <span class="o">=</span> <span class="p">[</span>
  87. <span class="s1">&#39;background&#39;</span><span class="p">,</span> <span class="s1">&#39;airplane&#39;</span><span class="p">,</span> <span class="s1">&#39;bicycle&#39;</span><span class="p">,</span> <span class="s1">&#39;bird&#39;</span><span class="p">,</span> <span class="s1">&#39;boat&#39;</span><span class="p">,</span> <span class="s1">&#39;bottle&#39;</span><span class="p">,</span>
  88. <span class="s1">&#39;bus&#39;</span><span class="p">,</span> <span class="s1">&#39;car&#39;</span><span class="p">,</span> <span class="s1">&#39;cat&#39;</span><span class="p">,</span> <span class="s1">&#39;chair&#39;</span><span class="p">,</span> <span class="s1">&#39;cow&#39;</span><span class="p">,</span> <span class="s1">&#39;diningtable&#39;</span><span class="p">,</span> <span class="s1">&#39;dog&#39;</span><span class="p">,</span> <span class="s1">&#39;horse&#39;</span><span class="p">,</span>
  89. <span class="s1">&#39;motorcycle&#39;</span><span class="p">,</span> <span class="s1">&#39;person&#39;</span><span class="p">,</span> <span class="s1">&#39;potted-plant&#39;</span><span class="p">,</span> <span class="s1">&#39;sheep&#39;</span><span class="p">,</span> <span class="s1">&#39;sofa&#39;</span><span class="p">,</span> <span class="s1">&#39;train&#39;</span><span class="p">,</span>
  90. <span class="s1">&#39;tv&#39;</span>
  91. <span class="p">]</span>
  92. <div class="viewcode-block" id="PascalAUG2012SegmentationDataSet"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.segmentation_datasets.html#super_gradients.training.datasets.PascalAUG2012SegmentationDataSet">[docs]</a><span class="k">class</span> <span class="nc">PascalAUG2012SegmentationDataSet</span><span class="p">(</span><span class="n">PascalVOC2012SegmentationDataSet</span><span class="p">):</span>
  93. <span class="sd">&quot;&quot;&quot;</span>
  94. <span class="sd"> PascalAUG2012SegmentationDataSet - Segmentation Data Set Class for Pascal AUG 2012 Data Set</span>
  95. <span class="sd"> &quot;&quot;&quot;</span>
  96. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  97. <span class="bp">self</span><span class="o">.</span><span class="n">sample_suffix</span> <span class="o">=</span> <span class="s1">&#39;.jpg&#39;</span>
  98. <span class="bp">self</span><span class="o">.</span><span class="n">target_suffix</span> <span class="o">=</span> <span class="s1">&#39;.mat&#39;</span>
  99. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">sample_suffix</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">sample_suffix</span><span class="p">,</span> <span class="n">target_suffix</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">target_suffix</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
  100. <span class="c1"># THERE ARE 21 CLASSES, INCLUDING BACKGROUND</span>
  101. <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="n">PASCAL_AUG_CLASSES</span>
  102. <div class="viewcode-block" id="PascalAUG2012SegmentationDataSet.target_loader"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.segmentation_datasets.html#super_gradients.training.datasets.PascalAUG2012SegmentationDataSet.target_loader">[docs]</a> <span class="nd">@staticmethod</span>
  103. <span class="k">def</span> <span class="nf">target_loader</span><span class="p">(</span><span class="n">target_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Image</span><span class="p">:</span>
  104. <span class="sd">&quot;&quot;&quot;</span>
  105. <span class="sd"> target_loader</span>
  106. <span class="sd"> :param target_path: The path to the target data</span>
  107. <span class="sd"> :return: The loaded target</span>
  108. <span class="sd"> &quot;&quot;&quot;</span>
  109. <span class="n">mat</span> <span class="o">=</span> <span class="n">scipy</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">loadmat</span><span class="p">(</span><span class="n">target_path</span><span class="p">,</span> <span class="n">mat_dtype</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">squeeze_me</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
  110. <span class="n">struct_as_record</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
  111. <span class="n">mask</span> <span class="o">=</span> <span class="n">mat</span><span class="p">[</span><span class="s1">&#39;GTcls&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">Segmentation</span>
  112. <span class="k">return</span> <span class="n">Image</span><span class="o">.</span><span class="n">fromarray</span><span class="p">(</span><span class="n">mask</span><span class="p">)</span></div></div>
  113. </pre></div>
  114. </div>
  115. </div>
  116. <footer>
  117. <hr/>
  118. <div role="contentinfo">
  119. <p>&#169; Copyright 2021, SuperGradients team.</p>
  120. </div>
  121. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  122. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  123. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  124. </footer>
  125. </div>
  126. </div>
  127. </section>
  128. </div>
  129. <script>
  130. jQuery(function () {
  131. SphinxRtdTheme.Navigation.enable(true);
  132. });
  133. </script>
  134. </body>
  135. </html>
Discard
Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
@@ -3,10 +3,11 @@
 <head>
 <head>
   <meta charset="utf-8" />
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-  <title>super_gradients.training.datasets.segmentation_datasets.pascal_voc_segmentation &mdash; SuperGradients 1.0 documentation</title>
+  <title>super_gradients.training.datasets.segmentation_datasets.pascal_voc_segmentation &mdash; SuperGradients 3.0.3 documentation</title>
       <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
+      <link rel="stylesheet" href="../../../../../_static/custom.css" type="text/css" />
   <!--[if lt IE 9]>
   <!--[if lt IE 9]>
     <script src="../../../../../_static/js/html5shiv.min.js"></script>
     <script src="../../../../../_static/js/html5shiv.min.js"></script>
   <![endif]-->
   <![endif]-->
@@ -14,7 +15,9 @@
         <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
         <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
         <script src="../../../../../_static/jquery.js"></script>
         <script src="../../../../../_static/jquery.js"></script>
         <script src="../../../../../_static/underscore.js"></script>
         <script src="../../../../../_static/underscore.js"></script>
+        <script src="../../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
         <script src="../../../../../_static/doctools.js"></script>
         <script src="../../../../../_static/doctools.js"></script>
+        <script src="../../../../../_static/sphinx_highlight.js"></script>
     <script src="../../../../../_static/js/theme.js"></script>
     <script src="../../../../../_static/js/theme.js"></script>
     <link rel="index" title="Index" href="../../../../../genindex.html" />
     <link rel="index" title="Index" href="../../../../../genindex.html" />
     <link rel="search" title="Search" href="../../../../../search.html" /> 
     <link rel="search" title="Search" href="../../../../../search.html" /> 
@@ -35,30 +38,29 @@
   </form>
   </form>
 </div>
 </div>
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
-              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
+              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
 <ul>
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#supergradients">SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#quick-installation">Quick Installation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#what-s-new">What’s New</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#coming-soon">Coming soon</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#table-of-content">Table of Content</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#getting-started">Getting Started</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#advanced-features">Advanced Features</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#installation-methods">Installation Methods</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#documentation">Documentation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#contributing">Contributing</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#citation">Citation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#community">Community</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#license">License</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#deci-platform">Deci Platform</a></li>
 </ul>
 </ul>
-<p class="caption"><span class="caption-text">Technical Documentation</span></p>
+<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
 <ul>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
-</ul>
-<p class="caption"><span class="caption-text">User Guide</span></p>
-<ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html">What is SuperGradients?</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#installation">Installation</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#training-parameters">Training Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#network-architectures">Network Architectures</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
 </ul>
 </ul>
 
 
         </div>
         </div>
@@ -89,31 +91,71 @@
 <span></span><span class="kn">import</span> <span class="nn">os</span>
 <span></span><span class="kn">import</span> <span class="nn">os</span>
 
 
 <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
 <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
+<span class="kn">import</span> <span class="nn">scipy.io</span>
+<span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span>
+<span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">ConcatDataset</span>
 
 
 <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.segmentation_datasets.segmentation_dataset</span> <span class="kn">import</span> <span class="n">SegmentationDataSet</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.segmentation_datasets.segmentation_dataset</span> <span class="kn">import</span> <span class="n">SegmentationDataSet</span>
+<span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
+
+<span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
 
 
 <span class="n">PASCAL_VOC_2012_CLASSES</span> <span class="o">=</span> <span class="p">[</span>
 <span class="n">PASCAL_VOC_2012_CLASSES</span> <span class="o">=</span> <span class="p">[</span>
-    <span class="s1">&#39;background&#39;</span><span class="p">,</span> <span class="s1">&#39;aeroplane&#39;</span><span class="p">,</span> <span class="s1">&#39;bicycle&#39;</span><span class="p">,</span> <span class="s1">&#39;bird&#39;</span><span class="p">,</span> <span class="s1">&#39;boat&#39;</span><span class="p">,</span> <span class="s1">&#39;bottle&#39;</span><span class="p">,</span>
-    <span class="s1">&#39;bus&#39;</span><span class="p">,</span> <span class="s1">&#39;car&#39;</span><span class="p">,</span> <span class="s1">&#39;cat&#39;</span><span class="p">,</span> <span class="s1">&#39;chair&#39;</span><span class="p">,</span> <span class="s1">&#39;cow&#39;</span><span class="p">,</span> <span class="s1">&#39;diningtable&#39;</span><span class="p">,</span> <span class="s1">&#39;dog&#39;</span><span class="p">,</span> <span class="s1">&#39;horse&#39;</span><span class=
-    <span class="s1">&#39;motorbike&#39;</span><span class="p">,</span> <span class="s1">&#39;person&#39;</span><span class="p">,</span> <span class="s1">&#39;potted-plant&#39;</span><span class="p">,</span> <span class="s1">&#39;sheep&#39;</span><span class="p">,</span> <span class="s1">&#39;sofa&#39;</span><span class="p">,</span> <span class="s1">&#39;train&#39;</span><span class="p">,</span>
-    <span class="s1">&#39;tv/monitor&#39;</span><span class="p">,</span> <span class="s1">&#39;ambigious&#39;</span>
+    <span class="s2">&quot;background&quot;</span><span class="p">,</span>
+    <span class="s2">&quot;aeroplane&quot;</span><span class="p">,</span>
+    <span class="s2">&quot;bicycle&quot;</span><span class="p">,</span>
+    <span class="s2">&quot;bird&quot;</span><span class="p">,</span>
+    <span class="s2">&quot;boat&quot;</span><span class="p">,</span>
+    <span class="s2">&quot;bottle&quot;</span><span class="p">,</span>
+    <span class="s2">&quot;bus&quot;</span><span class="p">,</span>
+    <span class="s2">&quot;car&quot;</span><span class="p">,</span>
+    <span class="s2">&quot;cat&quot;</span><span class="p">,</span>
+    <span class="s2">&quot;chair&quot;</span><span class="p">,</span>
+    <span class="s2">&quot;cow&quot;</span><span class="p">,</span>
+    <span class="s2">&quot;diningtable&quot;</span><span class="p">,</span>
+    <span class="s2">&quot;dog&quot;</span><span class="p">,</span>
+    <span class="s2">&quot;horse&quot;</span><span class="p">,</span>
+    <span class="s2">&quot;motorbike&quot;</span><span class="p">,</span>
+    <span class="s2">&quot;person&quot;</span><span class="p">,</span>
+    <span class="s2">&quot;potted-plant&quot;</span><span class="p">,</span>
+    <span class="s2">&quot;sheep&quot;</span><span class="p">,</span>
+    <span class="s2">&quot;sofa&quot;</span><span class="p">,</span>
+    <span class="s2">&quot;train&quot;</span><span class="p">,</span>
+    <span class="s2">&quot;tv/monitor&quot;</span><span class="p">,</span>
 <span class="p">]</span>
 <span class="p">]</span>
 
 
 
 
-<div class="viewcode-block" id="PascalVOC2012SegmentationDataSet"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.segmentation_datasets.html#super_gradients.training.datasets.PascalVOC2012SegmentationDataSet">[docs]</a><span class="k">class</span> <span class="nc">PascalVOC2012SegmentationDataSet</span><span class="p">(</span><span class="n">SegmentationDataSet</span><span class="p">):</span>
+<div class="viewcode-block" id="PascalVOC2012SegmentationDataSet"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.PascalVOC2012SegmentationDataSet">[docs]</a><span class="k">class</span> <span class="nc">PascalVOC2012SegmentationDataSet</span><span class="p">(</span><span class="n">SegmentationDataSet</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    PascalVOC2012SegmentationDataSet - Segmentation Data Set Class for Pascal VOC 2012 Data Set</span>
 <span class="sd">    PascalVOC2012SegmentationDataSet - Segmentation Data Set Class for Pascal VOC 2012 Data Set</span>
 <span class="sd">    &quot;&quot;&quot;</span>
 <span class="sd">    &quot;&quot;&quot;</span>
 
 
+    <span class="n">IGNORE_LABEL</span> <span class="o">=</span> <span class="mi">21</span>
+    <span class="n">_ORIGINAL_IGNORE_LABEL</span> <span class="o">=</span> <span class="mi">255</span>
+
     <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample_suffix</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">target_suffix</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><sp
     <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample_suffix</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">target_suffix</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><sp
-        <span class="bp">self</span><span class="o">.</span><span class="n">sample_suffix</span> <span class="o">=</span> <span class="s1">&#39;.jpg&#39;</span> <span class="k">if</span> <span class="n">sample_suffix</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">sample_suffix</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">target_suffix</span> <span class="o">=</span> <span class="s1">&#39;.png&#39;</span> <span class="k">if</span> <span class="n">target_suffix</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">target_suffix</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">sample_suffix</span> <span class="o">=</span> <span class="s2">&quot;.jpg&quot;</span> <span class="k">if</span> <span class="n">sample_suffix</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">sample_suffix</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">target_suffix</span> <span class="o">=</span> <span class="s2">&quot;.png&quot;</span> <span class="k">if</span> <span class="n">target_suffix</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">target_suffix</span>
         <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
         <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
 
 
-        <span class="c1"># THERE ARE 21 CLASSES, AND BACKGROUND</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="n">PASCAL_VOC_2012_CLASSES</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="n">PASCAL_VOC_2012_CLASSES</span>
 
 
-<div class="viewcode-block" id="PascalVOC2012SegmentationDataSet.decode_segmentation_mask"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.segmentation_datasets.html#super_gradients.training.datasets.PascalVOC2012SegmentationDataSet.decode_segmentation_mask">[docs]</a>    <span class="k">def</span> <span class="nf">decode_segmentation_mask</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">label_mask</span><span cla
+<div class="viewcode-block" id="PascalVOC2012SegmentationDataSet.target_transform"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.PascalVOC2012SegmentationDataSet.target_transform">[docs]</a>    <span class="nd">@staticmethod</span>
+    <span class="k">def</span> <span class="nf">target_transform</span><span class="p">(</span><span class="n">target</span><span class="p">):</span>
+        <span class="sd">&quot;&quot;&quot;</span>
+<span class="sd">        target_transform - Transforms the label mask</span>
+<span class="sd">        This function overrides the original function from SegmentationDataSet and changes target pixels with value</span>
+<span class="sd">        255 to value = IGNORE_LABEL. This was done since current IoU metric from torchmetrics does not</span>
+<span class="sd">        support such a high ignore label value (crashed on OOM)</span>
+
+<span class="sd">            :param target: The target mask to transform</span>
+<span class="sd">            :return:       The transformed target mask</span>
+<span class="sd">        &quot;&quot;&quot;</span>
+        <span class="n">out</span> <span class="o">=</span> <span class="n">SegmentationDataSet</span><span class="o">.</span><span class="n">target_transform</span><span class="p">(</span><span class="n">target</span><span class="p">)</span>
+        <span class="n">out</span><span class="p">[</span><span class="n">out</span> <span class="o">==</span> <span class="n">PascalVOC2012SegmentationDataSet</span><span class="o">.</span><span class="n">_ORIGINAL_IGNORE_LABEL</span><span class="p">]</span> <span class="o">=</span> <span class="n">PascalVOC2012SegmentationDataSet</span><span class="o">.</span><span class="n">IGNORE_LABEL</span>
+        <span class="k">return</span> <span class="n">out</span></div>
+
+<div class="viewcode-block" id="PascalVOC2012SegmentationDataSet.decode_segmentation_mask"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.PascalVOC2012SegmentationDataSet.decode_segmentation_mask">[docs]</a>    <span class="k">def</span> <span class="nf">decode_segmentation_mask</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">label_mask</span><span class="p">:</span> <span class="n"
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">        decode_segmentation_mask - Decodes the colors for the Segmentation Mask</span>
 <span class="sd">        decode_segmentation_mask - Decodes the colors for the Segmentation Mask</span>
 <span class="sd">            :param: label_mask:  an (M,N) array of integer values denoting</span>
 <span class="sd">            :param: label_mask:  an (M,N) array of integer values denoting</span>
@@ -125,8 +167,7 @@
         <span class="n">g</span> <span class="o">=</span> <span class="n">label_mask</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
         <span class="n">g</span> <span class="o">=</span> <span class="n">label_mask</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
         <span class="n">b</span> <span class="o">=</span> <span class="n">label_mask</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
         <span class="n">b</span> <span class="o">=</span> <span class="n">label_mask</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
 
 
-        <span class="c1"># REMOVING THE BACKGROUND CLASS FROM THE PLOTS</span>
-        <span class="n">num_classes_to_plot</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">classes</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span>
+        <span class="n">num_classes_to_plot</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">classes</span><span class="p">)</span>
         <span class="k">for</span> <span class="n">ll</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">num_classes_to_plot</span><span class="p">):</span>
         <span class="k">for</span> <span class="n">ll</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">num_classes_to_plot</span><span class="p">):</span>
             <span class="n">r</span><span class="p">[</span><span class="n">label_mask</span> <span class="o">==</span> <span class="n">ll</span><span class="p">]</span> <span class="o">=</span> <span class="n">label_colours</span><span class="p">[</span><span class="n">ll</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span>
             <span class="n">r</span><span class="p">[</span><span class="n">label_mask</span> <span class="o">==</span> <span class="n">ll</span><span class="p">]</span> <span class="o">=</span> <span class="n">label_colours</span><span class="p">[</span><span class="n">ll</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span>
             <span class="n">g</span><span class="p">[</span><span class="n">label_mask</span> <span class="o">==</span> <span class="n">ll</span><span class="p">]</span> <span class="o">=</span> <span class="n">label_colours</span><span class="p">[</span><span class="n">ll</span><span class="p">,</span> <span class="mi">1</span><span class="p">]</span>
             <span class="n">g</span><span class="p">[</span><span class="n">label_mask</span> <span class="o">==</span> <span class="n">ll</span><span class="p">]</span> <span class="o">=</span> <span class="n">label_colours</span><span class="p">[</span><span class="n">ll</span><span class="p">,</span> <span class="mi">1</span><span class="p">]</span>
@@ -145,8 +186,8 @@
         <span class="c1"># GENERATE SAMPLES AND TARGETS HERE SPECIFICALLY FOR PASCAL VOC 2012</span>
         <span class="c1"># GENERATE SAMPLES AND TARGETS HERE SPECIFICALLY FOR PASCAL VOC 2012</span>
         <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">root</span> <span class="o">+</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">sep</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">list_file_path</span><span class="p">,</span> <span class="s2">&quot;r&quot;</span><sp
         <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">root</span> <span class="o">+</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">sep</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">list_file_path</span><span class="p">,</span> <span class="s2">&quot;r&quot;</span><sp
             <span class="k">for</span> <span class="n">line</span> <span class="ow">in</span> <span class="n">lines</span><span class="p">:</span>
             <span class="k">for</span> <span class="n">line</span> <span class="ow">in</span> <span class="n">lines</span><span class="p">:</span>
-                <span class="n">image_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">root</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">samples_sub_directory</span><span class="p">,</span> <span class="n">line</span><span class="o">.</sp
-                <span class="n">mask_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">root</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">targets_sub_directory</span><span class="p">,</span> <span class="n">line</span><span class="o">.</spa
+                <span class="n">image_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">root</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">samples_sub_directory</span><span class="p">,</span> <span class="n">line</span><span class="o">.</sp
+                <span class="n">mask_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">root</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">targets_sub_directory</span><span class="p">,</span> <span class="n">line</span><span class="o">.</spa
 
 
                 <span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">mask_path</span><span class="p">)</span> <span class="ow">and</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">image_path</span><span class="p">):</span>
                 <span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">mask_path</span><span class="p">)</span> <span class="ow">and</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">image_path</span><span class="p">):</span>
                     <span class="bp">self</span><span class="o">.</span><span class="n">samples_targets_tuples_list</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">image_path</span><span class="p">,</span> <span class="n">mask_path</span><span class="p">))</span>
                     <span class="bp">self</span><span class="o">.</span><span class="n">samples_targets_tuples_list</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">image_path</span><span class="p">,</span> <span class="n">mask_path</span><span class="p">))</span>
@@ -184,6 +225,59 @@
                 <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">64</span><span class="p">,</span> <span class="mi">128</span><span class="p">],</span>
                 <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">64</span><span class="p">,</span> <span class="mi">128</span><span class="p">],</span>
             <span class="p">]</span>
             <span class="p">]</span>
         <span class="p">)</span></div>
         <span class="p">)</span></div>
+
+
+<div class="viewcode-block" id="PascalAUG2012SegmentationDataSet"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.PascalAUG2012SegmentationDataSet">[docs]</a><span class="k">class</span> <span class="nc">PascalAUG2012SegmentationDataSet</span><span class="p">(</span><span class="n">PascalVOC2012SegmentationDataSet</span><span class="p">):</span>
+    <span class="sd">&quot;&quot;&quot;</span>
+<span class="sd">    PascalAUG2012SegmentationDataSet - Segmentation Data Set Class for Pascal AUG 2012 Data Set</span>
+<span class="sd">    &quot;&quot;&quot;</span>
+
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">sample_suffix</span> <span class="o">=</span> <span class="s2">&quot;.jpg&quot;</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">target_suffix</span> <span class="o">=</span> <span class="s2">&quot;.mat&quot;</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">sample_suffix</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">sample_suffix</span><span class="p">,</span> <span class="n">target_suffix</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">target_suffix</span><span class="p">,</span> <span class
+
+<div class="viewcode-block" id="PascalAUG2012SegmentationDataSet.target_loader"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.PascalAUG2012SegmentationDataSet.target_loader">[docs]</a>    <span class="nd">@staticmethod</span>
+    <span class="k">def</span> <span class="nf">target_loader</span><span class="p">(</span><span class="n">target_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Image</span><span class="p">:</span>
+        <span class="sd">&quot;&quot;&quot;</span>
+<span class="sd">        target_loader</span>
+<span class="sd">            :param target_path: The path to the target data</span>
+<span class="sd">            :return:            The loaded target</span>
+<span class="sd">        &quot;&quot;&quot;</span>
+        <span class="n">mat</span> <span class="o">=</span> <span class="n">scipy</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">loadmat</span><span class="p">(</span><span class="n">target_path</span><span class="p">,</span> <span class="n">mat_dtype</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">squeeze_me</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span 
+        <span class="n">mask</span> <span class="o">=</span> <span class="n">mat</span><span class="p">[</span><span class="s2">&quot;GTcls&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">Segmentation</span>
+        <span class="k">return</span> <span class="n">Image</span><span class="o">.</span><span class="n">fromarray</span><span class="p">(</span><span class="n">mask</span><span class="p">)</span></div></div>
+
+
+<div class="viewcode-block" id="PascalVOCAndAUGUnifiedDataset"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.PascalVOCAndAUGUnifiedDataset">[docs]</a><span class="k">class</span> <span class="nc">PascalVOCAndAUGUnifiedDataset</span><span class="p">(</span><span class="n">ConcatDataset</span><span class="p">):</span>
+    <span class="sd">&quot;&quot;&quot;</span>
+<span class="sd">    Pascal VOC + AUG train dataset, aka `SBD` dataset contributed in &quot;Semantic contours from inverse detectors&quot;.</span>
+<span class="sd">    This is class implement the common usage of the SBD and PascalVOC datasets as a unified augmented trainset.</span>
+<span class="sd">    The unified dataset includes a total of 10,582 samples and don&#39;t contains duplicate samples from the PascalVOC</span>
+<span class="sd">    validation set.</span>
+<span class="sd">    &quot;&quot;&quot;</span>
+
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
+        <span class="nb">print</span><span class="p">(</span><span class="n">kwargs</span><span class="p">)</span>
+        <span class="k">if</span> <span class="nb">any</span><span class="p">([</span><span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;list_file&quot;</span><span class="p">),</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;samples_sub_directory&quot;</span><span class="p">),</span> <span class="n">kwargs</span><span class="o">.</span>
+            <span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
+                <span class="s2">&quot;[list_file, samples_sub_directory, targets_sub_directory] arguments passed will not be used&quot;</span>
+                <span class="s2">&quot; when passed to `PascalVOCAndAUGUnifiedDataset`. Those values are predefined for initiating&quot;</span>
+                <span class="s2">&quot; the Pascal VOC + AUG training set.&quot;</span>
+            <span class="p">)</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
+            <span class="n">datasets</span><span class="o">=</span><span class="p">[</span>
+                <span class="n">PascalVOC2012SegmentationDataSet</span><span class="p">(</span>
+                    <span class="n">list_file</span><span class="o">=</span><span class="s2">&quot;VOCdevkit/VOC2012/ImageSets/Segmentation/train.txt&quot;</span><span class="p">,</span>
+                    <span class="n">samples_sub_directory</span><span class="o">=</span><span class="s2">&quot;VOCdevkit/VOC2012/JPEGImages&quot;</span><span class="p">,</span>
+                    <span class="n">targets_sub_directory</span><span class="o">=</span><span class="s2">&quot;VOCdevkit/VOC2012/SegmentationClass&quot;</span><span class="p">,</span>
+                    <span class="o">**</span><span class="n">kwargs</span><span class="p">,</span>
+                <span class="p">),</span>
+                <span class="n">PascalAUG2012SegmentationDataSet</span><span class="p">(</span>
+                    <span class="n">list_file</span><span class="o">=</span><span class="s2">&quot;VOCaug/dataset/aug.txt&quot;</span><span class="p">,</span> <span class="n">samples_sub_directory</span><span class="o">=</span><span class="s2">&quot;VOCaug/dataset/img&quot;</span><span class="p">,</span> <span class="n">targets_sub_directory</span><span class="o">=</span><span class="s2">&quot;VOCaug/dataset/cls&quot;</span><span class="p">,</span> <span class="o">**</span><span class="n">kwarg
+                <span class="p">),</span>
+            <span class="p">]</span>
+        <span class="p">)</span></div>
 </pre></div>
 </pre></div>
 
 
            </div>
            </div>
@@ -213,4 +307,4 @@
   </script> 
   </script> 
 
 
 </body>
 </body>
-</html>
+</html>
Discard
Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
@@ -3,10 +3,11 @@
 <head>
 <head>
   <meta charset="utf-8" />
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-  <title>super_gradients.training.datasets.segmentation_datasets.segmentation_dataset &mdash; SuperGradients 1.0 documentation</title>
+  <title>super_gradients.training.datasets.segmentation_datasets.segmentation_dataset &mdash; SuperGradients 3.0.3 documentation</title>
       <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
+      <link rel="stylesheet" href="../../../../../_static/custom.css" type="text/css" />
   <!--[if lt IE 9]>
   <!--[if lt IE 9]>
     <script src="../../../../../_static/js/html5shiv.min.js"></script>
     <script src="../../../../../_static/js/html5shiv.min.js"></script>
   <![endif]-->
   <![endif]-->
@@ -14,7 +15,9 @@
         <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
         <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
         <script src="../../../../../_static/jquery.js"></script>
         <script src="../../../../../_static/jquery.js"></script>
         <script src="../../../../../_static/underscore.js"></script>
         <script src="../../../../../_static/underscore.js"></script>
+        <script src="../../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
         <script src="../../../../../_static/doctools.js"></script>
         <script src="../../../../../_static/doctools.js"></script>
+        <script src="../../../../../_static/sphinx_highlight.js"></script>
     <script src="../../../../../_static/js/theme.js"></script>
     <script src="../../../../../_static/js/theme.js"></script>
     <link rel="index" title="Index" href="../../../../../genindex.html" />
     <link rel="index" title="Index" href="../../../../../genindex.html" />
     <link rel="search" title="Search" href="../../../../../search.html" /> 
     <link rel="search" title="Search" href="../../../../../search.html" /> 
@@ -35,30 +38,29 @@
   </form>
   </form>
 </div>
 </div>
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
-              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
+              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
 <ul>
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#supergradients">SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#quick-installation">Quick Installation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#what-s-new">What’s New</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#coming-soon">Coming soon</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#table-of-content">Table of Content</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#getting-started">Getting Started</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#advanced-features">Advanced Features</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#installation-methods">Installation Methods</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#documentation">Documentation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#contributing">Contributing</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#citation">Citation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#community">Community</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#license">License</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#deci-platform">Deci Platform</a></li>
 </ul>
 </ul>
-<p class="caption"><span class="caption-text">Technical Documentation</span></p>
+<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
 <ul>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
-</ul>
-<p class="caption"><span class="caption-text">User Guide</span></p>
-<ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html">What is SuperGradients?</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#installation">Installation</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#training-parameters">Training Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#network-architectures">Network Architectures</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
 </ul>
 </ul>
 
 
         </div>
         </div>
@@ -87,71 +89,44 @@
              
              
   <h1>Source code for super_gradients.training.datasets.segmentation_datasets.segmentation_dataset</h1><div class="highlight"><pre>
   <h1>Source code for super_gradients.training.datasets.segmentation_datasets.segmentation_dataset</h1><div class="highlight"><pre>
 <span></span><span class="kn">import</span> <span class="nn">os</span>
 <span></span><span class="kn">import</span> <span class="nn">os</span>
-<span class="kn">import</span> <span class="nn">torch</span>
-<span class="kn">import</span> <span class="nn">random</span>
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">Iterable</span>
+
 <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
 <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
-<span class="kn">from</span> <span class="nn">tqdm</span> <span class="kn">import</span> <span class="n">tqdm</span>
-<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Callable</span>
+<span class="kn">import</span> <span class="nn">torch</span>
 <span class="kn">import</span> <span class="nn">torchvision.transforms</span> <span class="k">as</span> <span class="nn">transform</span>
 <span class="kn">import</span> <span class="nn">torchvision.transforms</span> <span class="k">as</span> <span class="nn">transform</span>
 <span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span>
 <span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span>
+<span class="kn">from</span> <span class="nn">tqdm</span> <span class="kn">import</span> <span class="n">tqdm</span>
 
 
 <span class="kn">from</span> <span class="nn">super_gradients.common.decorators.factory_decorator</span> <span class="kn">import</span> <span class="n">resolve_param</span>
 <span class="kn">from</span> <span class="nn">super_gradients.common.decorators.factory_decorator</span> <span class="kn">import</span> <span class="n">resolve_param</span>
 <span class="kn">from</span> <span class="nn">super_gradients.common.factories.transforms_factory</span> <span class="kn">import</span> <span class="n">TransformsFactory</span>
 <span class="kn">from</span> <span class="nn">super_gradients.common.factories.transforms_factory</span> <span class="kn">import</span> <span class="n">TransformsFactory</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.sg_dataset</span> <span class="kn">import</span> <span class="n">DirectoryDataSet</span><span class="p">,</span> <span class="n">ListDataset</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.sg_dataset</span> <span class="kn">import</span> <span class="n">DirectoryDataSet</span><span class="p">,</span> <span class="n">ListDataset</span>
-<span class="kn">from</span> <span class="nn">super_gradients.training.transforms.transforms</span> <span class="kn">import</span> <span class="n">RandomFlip</span><span class="p">,</span> <span class="n">Rescale</span><span class="p">,</span> <span class="n">RandomRescale</span><span class="p">,</span> <span class="n">RandomRotate</span><span class="p">,</span> \
-    <span class="n">CropImageAndMask</span><span class="p">,</span> <span class="n">RandomGaussianBlur</span><span class="p">,</span> <span class="n">PadShortToCropSize</span>
 
 
 
 
-<div class="viewcode-block" id="SegmentationDataSet"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.segmentation_datasets.html#super_gradients.training.datasets.SegmentationDataSet">[docs]</a><span class="k">class</span> <span class="nc">SegmentationDataSet</span><span class="p">(</span><span class="n">DirectoryDataSet</span><span class="p">,</span> <span class="n">ListDataset</span><span class="p">):</span>
+<div class="viewcode-block" id="SegmentationDataSet"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.SegmentationDataSet">[docs]</a><span class="k">class</span> <span class="nc">SegmentationDataSet</span><span class="p">(</span><span class="n">DirectoryDataSet</span><span class="p">,</span> <span class="n">ListDataset</span><span class="p">):</span>
 
 
-    <span class="nd">@resolve_param</span><span class="p">(</span><span class="s1">&#39;image_mask_transforms&#39;</span><span class="p">,</span> <span class="n">factory</span><span class="o">=</span><span class="n">TransformsFactory</span><span class="p">())</span>
-    <span class="nd">@resolve_param</span><span class="p">(</span><span class="s1">&#39;image_mask_transforms_aug&#39;</span><span class="p">,</span> <span class="n">factory</span><span class="o">=</span><span class="n">TransformsFactory</span><span class="p">())</span>
+    <span class="nd">@resolve_param</span><span class="p">(</span><span class="s1">&#39;transforms&#39;</span><span class="p">,</span> <span class="n">factory</span><span class="o">=</span><span class="n">TransformsFactory</span><span class="p">())</span>
     <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">root</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">list_file</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">samples_sub_directory</span><span class="p">:</span> <span class="nb">str</s
     <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">root</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">list_file</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">samples_sub_directory</span><span class="p">:</span> <span class="nb">str</s
                  <span class="n">targets_sub_directory</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
                  <span class="n">targets_sub_directory</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
-                 <span class="n">img_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">608</span><span class="p">,</span> <span class="n">crop_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">512</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">16</span><span class="p">,</spa
-                 <span class="n">dataset_hyper_params</span><span class="p">:</span> <span class="nb">dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
-                 <span class="n">cache_labels</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> <span class="n">cache_images</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> <span class="n">sample_loader</span><span class="p">:</span> <span class="n">Callable</span> <span class="o">=</span> <span class="kc">None</span>
-                 <span class="n">target_loader</span><span class="p">:</span> <span class="n">Callable</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">collate_fn</span><span class="p">:</span> <span class="n">Callable</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">target_extension</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39;.png
-                 <span class="n">image_mask_transforms</span><span class="p">:</span> <span class="n">transform</span><span class="o">.</span><span class="n">Compose</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">image_mask_transforms_aug</span><span class="p">:</span> <span class="n">transform</span><span class="o">.</span><span class="n">Compose</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
+                 <span class="n">cache_labels</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> <span class="n">cache_images</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
+                 <span class="n">collate_fn</span><span class="p">:</span> <span class="n">Callable</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">target_extension</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39;.png&#39;</span><span class="p">,</span>
+                 <span class="n">transforms</span><span class="p">:</span> <span class="n">Iterable</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">        SegmentationDataSet</span>
 <span class="sd">        SegmentationDataSet</span>
-<span class="sd">                                * Please use self.augment == True only for training</span>
-
 <span class="sd">            :param root:                        Root folder of the Data Set</span>
 <span class="sd">            :param root:                        Root folder of the Data Set</span>
 <span class="sd">            :param list_file:                   Path to the file with the samples list</span>
 <span class="sd">            :param list_file:                   Path to the file with the samples list</span>
 <span class="sd">            :param samples_sub_directory:       name of the samples sub-directory</span>
 <span class="sd">            :param samples_sub_directory:       name of the samples sub-directory</span>
 <span class="sd">            :param targets_sub_directory:       name of the targets sub-directory</span>
 <span class="sd">            :param targets_sub_directory:       name of the targets sub-directory</span>
-<span class="sd">            :param img_size:                    Image size of the Model that uses this Data Set</span>
-<span class="sd">            :param crop_size:                   The size of the cropped image</span>
-<span class="sd">            :param batch_size:                  Batch Size of the Model that uses this Data Set</span>
-<span class="sd">            :param augment:                     True / False flag to allow Augmentation</span>
-<span class="sd">            :param dataset_hyper_params:        Any hyper params required for the data set</span>
 <span class="sd">            :param cache_labels:                &quot;Caches&quot; the labels -&gt; Pre-Loads to memory as a list</span>
 <span class="sd">            :param cache_labels:                &quot;Caches&quot; the labels -&gt; Pre-Loads to memory as a list</span>
 <span class="sd">            :param cache_images:                &quot;Caches&quot; the images -&gt; Pre-Loads to memory as a list</span>
 <span class="sd">            :param cache_images:                &quot;Caches&quot; the images -&gt; Pre-Loads to memory as a list</span>
-<span class="sd">            :param sample_loader:               A function that specifies how to load a sample</span>
-<span class="sd">            :param target_loader:               A function that specifies how to load a target</span>
 <span class="sd">            :param collate_fn:                  collate_fn func to process batches for the Data Loader</span>
 <span class="sd">            :param collate_fn:                  collate_fn func to process batches for the Data Loader</span>
-<span class="sd">            :param target_extension:            file extension of the targets (defualt is .png for PASCAL VOC 2012)</span>
-<span class="sd">            :param image_mask_transforms        transforms to be applied on image and mask when augment=False</span>
-<span class="sd">            :param image_mask_transforms_aug    transforms to be applied on image and mask when augment=True</span>
+<span class="sd">            :param target_extension:            file extension of the targets (default is .png for PASCAL VOC 2012)</span>
+<span class="sd">            :param transforms:                  transforms to be applied on image and mask</span>
+
 <span class="sd">        &quot;&quot;&quot;</span>
 <span class="sd">        &quot;&quot;&quot;</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">samples_sub_directory</span> <span class="o">=</span> <span class="n">samples_sub_directory</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">samples_sub_directory</span> <span class="o">=</span> <span class="n">samples_sub_directory</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">targets_sub_directory</span> <span class="o">=</span> <span class="n">targets_sub_directory</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">targets_sub_directory</span> <span class="o">=</span> <span class="n">targets_sub_directory</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">dataset_hyperparams</span> <span class="o">=</span> <span class="n">dataset_hyper_params</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">cache_labels</span> <span class="o">=</span> <span class="n">cache_labels</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">cache_labels</span> <span class="o">=</span> <span class="n">cache_labels</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">cache_images</span> <span class="o">=</span> <span class="n">cache_images</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">cache_images</span> <span class="o">=</span> <span class="n">cache_images</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="n">batch_size</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">img_size</span> <span class="o">=</span> <span class="n">img_size</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span> <span class="o">=</span> <span class="n">crop_size</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">augment</span> <span class="o">=</span> <span class="n">augment</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">batch_index</span> <span class="o">=</span> <span class="kc">None</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">total_batches_num</span> <span class="o">=</span> <span class="kc">None</span>
-
-        <span class="c1"># ENABLES USING CUSTOM SAMPLE/TARGET LOADERS</span>
-        <span class="k">if</span> <span class="n">sample_loader</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
-            <span class="bp">self</span><span class="o">.</span><span class="n">sample_loader</span> <span class="o">=</span> <span class="n">sample_loader</span>
-        <span class="k">if</span> <span class="n">target_loader</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
-            <span class="bp">self</span><span class="o">.</span><span class="n">target_loader</span> <span class="o">=</span> <span class="n">target_loader</span>
 
 
         <span class="c1"># CREATE A DIRECTORY DATASET OR A LIST DATASET BASED ON THE list_file INPUT VARIABLE</span>
         <span class="c1"># CREATE A DIRECTORY DATASET OR A LIST DATASET BASED ON THE list_file INPUT VARIABLE</span>
         <span class="k">if</span> <span class="n">list_file</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
         <span class="k">if</span> <span class="n">list_file</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
@@ -165,26 +140,8 @@
                                       <span class="n">sample_loader</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">sample_loader</span><span class="p">,</span> <span class="n">sample_transform</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">sample_transform</span><span class="p">,</span>
                                       <span class="n">sample_loader</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">sample_loader</span><span class="p">,</span> <span class="n">sample_transform</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">sample_transform</span><span class="p">,</span>
                                       <span class="n">target_loader</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">target_loader</span><span class="p">,</span> <span class="n">target_transform</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">target_transform</span><span class="p">,</span>
                                       <span class="n">target_loader</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">target_loader</span><span class="p">,</span> <span class="n">target_transform</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">target_transform</span><span class="p">,</span>
                                       <span class="n">collate_fn</span><span class="o">=</span><span class="n">collate_fn</span><span class="p">)</span>
                                       <span class="n">collate_fn</span><span class="o">=</span><span class="n">collate_fn</span><span class="p">)</span>
-        <span class="c1"># DEFAULT TRANSFORMS</span>
-        <span class="c1"># FIXME - Rescale before RandomRescale is kept for legacy support, consider removing it like most implementation</span>
-        <span class="c1">#  papers regimes.</span>
-        <span class="n">default_image_mask_transforms_aug</span> <span class="o">=</span> <span class="n">transform</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span><span class="n">RandomFlip</span><span class="p">(),</span>
-                                                               <span class="n">Rescale</span><span class="p">(</span><span class="n">short_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">img_size</span><span class="p">),</span>
-                                                               <span class="n">RandomRescale</span><span class="p">(</span><span class="n">scales</span><span class="o">=</span><span class="p">(</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">)),</span>
-                                                               <span class="n">RandomRotate</span><span class="p">(),</span>
-                                                               <span class="n">PadShortToCropSize</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">),</span>
-                                                               <span class="n">CropImageAndMask</span><span class="p">(</span><span class="n">crop_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">,</span>
-                                                                                <span class="n">mode</span><span class="o">=</span><span class="s2">&quot;random&quot;</span><span class="p">),</span>
-                                                               <span class="n">RandomGaussianBlur</span><span class="p">()])</span>
-
-        <span class="bp">self</span><span class="o">.</span><span class="n">image_mask_transforms_aug</span> <span class="o">=</span> <span class="n">image_mask_transforms_aug</span> <span class="ow">or</span> <span class="n">default_image_mask_transforms_aug</span>
-        <span class="c1"># FIXME: CROP SIZE CANNOT BE PASSED WHEN LIST</span>
-        <span class="k">if</span> <span class="n">image_mask_transforms</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
-            <span class="n">image_mask_transforms</span> <span class="o">=</span> <span class="n">transform</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span><span class="n">Rescale</span><span class="p">(</span><span class="n">short_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">),</span>
-                                                       <span class="n">CropImageAndMask</span><span class="p">(</span><span class="n">crop_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s2">&quot;center&quot;</span><span class="p">)</span>
-                                                       <span class="p">])</span>
-
-        <span class="bp">self</span><span class="o">.</span><span class="n">image_mask_transforms</span> <span class="o">=</span> <span class="n">image_mask_transforms</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">transforms</span> <span class="o">=</span> <span class="n">transform</span><span class="o">.</span><span class="n">Compose</span><span class="p">(</span><span class="n">transforms</span> <span class="k">if</span> <span class="n">transforms</span> <span class="k">else</span> <span class="p">[])</span>
 
 
     <span class="k">def</span> <span class="fm">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">index</span><span class="p">):</span>
     <span class="k">def</span> <span class="fm">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">index</span><span class="p">):</span>
         <span class="n">sample_path</span><span class="p">,</span> <span class="n">target_path</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">samples_targets_tuples_list</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>
         <span class="n">sample_path</span><span class="p">,</span> <span class="n">target_path</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">samples_targets_tuples_list</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>
@@ -206,7 +163,7 @@
 
 
         <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">sample_transform</span><span class="p">(</span><span class="n">sample</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">target_transform</span><span class="p">(</span><span class="n">target</span><span class="p">)</span>
         <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">sample_transform</span><span class="p">(</span><span class="n">sample</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">target_transform</span><span class="p">(</span><span class="n">target</span><span class="p">)</span>
 
 
-<div class="viewcode-block" id="SegmentationDataSet.sample_loader"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.segmentation_datasets.html#super_gradients.training.datasets.SegmentationDataSet.sample_loader">[docs]</a>    <span class="nd">@staticmethod</span>
+<div class="viewcode-block" id="SegmentationDataSet.sample_loader"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.SegmentationDataSet.sample_loader">[docs]</a>    <span class="nd">@staticmethod</span>
     <span class="k">def</span> <span class="nf">sample_loader</span><span class="p">(</span><span class="n">sample_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Image</span><span class="p">:</span>
     <span class="k">def</span> <span class="nf">sample_loader</span><span class="p">(</span><span class="n">sample_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Image</span><span class="p">:</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">        sample_loader - Loads a dataset image from path using PIL</span>
 <span class="sd">        sample_loader - Loads a dataset image from path using PIL</span>
@@ -216,7 +173,7 @@
         <span class="n">image</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">sample_path</span><span class="p">)</span><span class="o">.</span><span class="n">convert</span><span class="p">(</span><span class="s1">&#39;RGB&#39;</span><span class="p">)</span>
         <span class="n">image</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">sample_path</span><span class="p">)</span><span class="o">.</span><span class="n">convert</span><span class="p">(</span><span class="s1">&#39;RGB&#39;</span><span class="p">)</span>
         <span class="k">return</span> <span class="n">image</span></div>
         <span class="k">return</span> <span class="n">image</span></div>
 
 
-<div class="viewcode-block" id="SegmentationDataSet.sample_transform"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.segmentation_datasets.html#super_gradients.training.datasets.SegmentationDataSet.sample_transform">[docs]</a>    <span class="nd">@staticmethod</span>
+<div class="viewcode-block" id="SegmentationDataSet.sample_transform"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.SegmentationDataSet.sample_transform">[docs]</a>    <span class="nd">@staticmethod</span>
     <span class="k">def</span> <span class="nf">sample_transform</span><span class="p">(</span><span class="n">image</span><span class="p">):</span>
     <span class="k">def</span> <span class="nf">sample_transform</span><span class="p">(</span><span class="n">image</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">        sample_transform - Transforms the sample image</span>
 <span class="sd">        sample_transform - Transforms the sample image</span>
@@ -230,7 +187,7 @@
 
 
         <span class="k">return</span> <span class="n">sample_transform</span><span class="p">(</span><span class="n">image</span><span class="p">)</span></div>
         <span class="k">return</span> <span class="n">sample_transform</span><span class="p">(</span><span class="n">image</span><span class="p">)</span></div>
 
 
-<div class="viewcode-block" id="SegmentationDataSet.target_loader"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.segmentation_datasets.html#super_gradients.training.datasets.SegmentationDataSet.target_loader">[docs]</a>    <span class="nd">@staticmethod</span>
+<div class="viewcode-block" id="SegmentationDataSet.target_loader"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.SegmentationDataSet.target_loader">[docs]</a>    <span class="nd">@staticmethod</span>
     <span class="k">def</span> <span class="nf">target_loader</span><span class="p">(</span><span class="n">target_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Image</span><span class="p">:</span>
     <span class="k">def</span> <span class="nf">target_loader</span><span class="p">(</span><span class="n">target_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Image</span><span class="p">:</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">        target_loader</span>
 <span class="sd">        target_loader</span>
@@ -240,7 +197,7 @@
         <span class="n">target</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">target_path</span><span class="p">)</span>
         <span class="n">target</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">target_path</span><span class="p">)</span>
         <span class="k">return</span> <span class="n">target</span></div>
         <span class="k">return</span> <span class="n">target</span></div>
 
 
-<div class="viewcode-block" id="SegmentationDataSet.target_transform"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.segmentation_datasets.html#super_gradients.training.datasets.SegmentationDataSet.target_transform">[docs]</a>    <span class="nd">@staticmethod</span>
+<div class="viewcode-block" id="SegmentationDataSet.target_transform"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.SegmentationDataSet.target_transform">[docs]</a>    <span class="nd">@staticmethod</span>
     <span class="k">def</span> <span class="nf">target_transform</span><span class="p">(</span><span class="n">target</span><span class="p">):</span>
     <span class="k">def</span> <span class="nf">target_transform</span><span class="p">(</span><span class="n">target</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">        target_transform - Transforms the sample image</span>
 <span class="sd">        target_transform - Transforms the sample image</span>
@@ -258,9 +215,6 @@
         <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">samples_targets_tuples_list</span><span class="p">:</span>
         <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">samples_targets_tuples_list</span><span class="p">:</span>
             <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">_generate_samples_and_targets</span><span class="p">()</span>
             <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">_generate_samples_and_targets</span><span class="p">()</span>
 
 
-        <span class="bp">self</span><span class="o">.</span><span class="n">batch_index</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">floor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="p">))</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.<
-        <span class="bp">self</span><span class="o">.</span><span class="n">total_batches_num</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_index</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="mi">1</span>
-
         <span class="c1"># EXTRACT THE LABELS FROM THE TUPLES LIST</span>
         <span class="c1"># EXTRACT THE LABELS FROM THE TUPLES LIST</span>
         <span class="n">image_files</span><span class="p">,</span> <span class="n">label_files</span> <span class="o">=</span> <span class="nb">map</span><span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">zip</span><span class="p">(</span><span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">samples_targets_tuples_list</span><span class="p">))</span>
         <span class="n">image_files</span><span class="p">,</span> <span class="n">label_files</span> <span class="o">=</span> <span class="nb">map</span><span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">zip</span><span class="p">(</span><span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">samples_targets_tuples_list</span><span class="p">))</span>
         <span class="n">image_indices_to_remove</span> <span class="o">=</span> <span class="p">[]</span>
         <span class="n">image_indices_to_remove</span> <span class="o">=</span> <span class="p">[]</span>
@@ -309,66 +263,13 @@
             <span class="bp">self</span><span class="o">.</span><span class="n">label_files</span> <span class="o">=</span> <span class="p">[</span><span class="n">e</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">e</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">label_files</span><span class="p">)</span> <span class="k">if</span> <span class="n">i</span> <span class="ow">not</span> <sp
             <span class="bp">self</span><span class="o">.</span><span class="n">label_files</span> <span class="o">=</span> <span class="p">[</span><span class="n">e</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">e</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">label_files</span><span class="p">)</span> <span class="k">if</span> <span class="n">i</span> <span class="ow">not</span> <sp
             <span class="bp">self</span><span class="o">.</span><span class="n">labels</span> <span class="o">=</span> <span class="p">[</span><span class="n">e</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">e</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">labels</span><span class="p">)</span> <span class="k">if</span> <span class="n
             <span class="bp">self</span><span class="o">.</span><span class="n">labels</span> <span class="o">=</span> <span class="p">[</span><span class="n">e</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">e</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">labels</span><span class="p">)</span> <span class="k">if</span> <span class="n
 
 
-    <span class="k">def</span> <span class="nf">_calculate_short_size</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">img</span><span class="p">):</span>
-        <span class="sd">&quot;&quot;&quot;</span>
-<span class="sd">        _calculate_crop</span>
-<span class="sd">        :param img:</span>
-<span class="sd">        :return:</span>
-<span class="sd">        &quot;&quot;&quot;</span>
-        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">augment</span><span class="p">:</span>
-            <span class="c1"># RANDOM SCALE (SHORT EDGE FROM 480 TO 720)</span>
-            <span class="n">short_size</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">img_size</span> <span class="o">*</span> <span class="mf">0.5</span><span class="p">),</span> <span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span
-        <span class="k">else</span><span class="p">:</span>
-            <span class="n">short_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span>
-
-        <span class="n">w</span><span class="p">,</span> <span class="n">h</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span class="n">size</span>
-        <span class="k">if</span> <span class="n">w</span> <span class="o">&gt;</span> <span class="n">h</span><span class="p">:</span>
-            <span class="n">oh</span> <span class="o">=</span> <span class="n">short_size</span>
-            <span class="n">ow</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="mf">1.0</span> <span class="o">*</span> <span class="n">w</span> <span class="o">*</span> <span class="n">oh</span> <span class="o">/</span> <span class="n">h</span><span class="p">)</span>
-        <span class="k">else</span><span class="p">:</span>
-            <span class="n">ow</span> <span class="o">=</span> <span class="n">short_size</span>
-            <span class="n">oh</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="mf">1.0</span> <span class="o">*</span> <span class="n">h</span> <span class="o">*</span> <span class="n">ow</span> <span class="o">/</span> <span class="n">w</span><span class="p">)</span>
-
-        <span class="k">return</span> <span class="n">oh</span><span class="p">,</span> <span class="n">ow</span><span class="p">,</span> <span class="n">short_size</span>
-
-    <span class="k">def</span> <span class="nf">_get_center_crop</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="n">h</span><span class="p">):</span>
-        <span class="sd">&quot;&quot;&quot;</span>
-
-<span class="sd">        :param w:</span>
-<span class="sd">        :param h:</span>
-<span class="sd">        :return:</span>
-<span class="sd">        &quot;&quot;&quot;</span>
-        <span class="c1"># CENTER CROP</span>
-        <span class="n">x1</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="nb">round</span><span class="p">((</span><span class="n">w</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">)</span> <span class="o">/</span> <span class="mf">2.</span><span class="p">))</span>
-        <span class="n">y1</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="nb">round</span><span class="p">((</span><span class="n">h</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">)</span> <span class="o">/</span> <span class="mf">2.</span><span class="p">))</span>
-
-        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">augment</span><span class="p">:</span>
-            <span class="c1"># RANDOM CROP CROP_SIZE</span>
-            <span class="n">x1</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">w</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">)</span>
-            <span class="n">y1</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">h</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">)</span>
-
-        <span class="k">return</span> <span class="n">x1</span><span class="p">,</span> <span class="n">y1</span>
-
     <span class="k">def</span> <span class="nf">_transform_image_and_mask</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">image</span><span class="p">,</span> <span class="n">mask</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">tuple</span><span class="p">:</span>
     <span class="k">def</span> <span class="nf">_transform_image_and_mask</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">image</span><span class="p">,</span> <span class="n">mask</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">tuple</span><span class="p">:</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
-<span class="sd">        _transform -  Transforms the input (image, mask) in the following order:</span>
-<span class="sd">                                1. FLIP (if augment==true)</span>
-<span class="sd">                                2. RESIZE</span>
-<span class="sd">                                3. ROTATE (if augment==true)</span>
-<span class="sd">                                4. CROP</span>
-<span class="sd">                                5. GAUSSIAN BLUR (if augment==true)</span>
-
-<span class="sd">                            * Please use self.augment == True only for training</span>
-
 <span class="sd">            :param image:           The input image</span>
 <span class="sd">            :param image:           The input image</span>
 <span class="sd">            :param mask:            The input mask</span>
 <span class="sd">            :param mask:            The input mask</span>
 <span class="sd">            :return:                The transformed image, mask</span>
 <span class="sd">            :return:                The transformed image, mask</span>
 <span class="sd">        &quot;&quot;&quot;</span>
 <span class="sd">        &quot;&quot;&quot;</span>
-        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">augment</span><span class="p">:</span>
-            <span class="n">transformed</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">image_mask_transforms_aug</span><span class="p">({</span><span class="s2">&quot;image&quot;</span><span class="p">:</span> <span class="n">image</span><span class="p">,</span> <span class="s2">&quot;mask&quot;</span><span class="p">:</span> <span class="n">mask</span><span class="p">})</span>
-        <span class="k">else</span><span class="p">:</span>
-            <span class="n">transformed</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">image_mask_transforms</span><span class="p">({</span><span class="s2">&quot;image&quot;</span><span class="p">:</span> <span class="n">image</span><span class="p">,</span> <span class="s2">&quot;mask&quot;</span><span class="p">:</span> <span class="n">mask</span><span class="p">})</span>
-
+        <span class="n">transformed</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">transforms</span><span class="p">({</span><span class="s2">&quot;image&quot;</span><span class="p">:</span> <span class="n">image</span><span class="p">,</span> <span class="s2">&quot;mask&quot;</span><span class="p">:</span> <span class="n">mask</span><span class="p">})</span>
         <span class="k">return</span> <span class="n">transformed</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">],</span> <span class="n">transformed</span><span class="p">[</span><span class="s2">&quot;mask&quot;</span><span class="p">]</span></div>
         <span class="k">return</span> <span class="n">transformed</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">],</span> <span class="n">transformed</span><span class="p">[</span><span class="s2">&quot;mask&quot;</span><span class="p">]</span></div>
 </pre></div>
 </pre></div>
 
 
@@ -399,4 +300,4 @@
   </script> 
   </script> 
 
 
 </body>
 </body>
-</html>
+</html>
Discard
@@ -3,10 +3,11 @@
 <head>
 <head>
   <meta charset="utf-8" />
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-  <title>super_gradients.training.datasets.segmentation_datasets.supervisely_persons_segmentation &mdash; SuperGradients 1.0 documentation</title>
+  <title>super_gradients.training.datasets.segmentation_datasets.supervisely_persons_segmentation &mdash; SuperGradients 3.0.3 documentation</title>
       <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
       <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
+      <link rel="stylesheet" href="../../../../../_static/custom.css" type="text/css" />
   <!--[if lt IE 9]>
   <!--[if lt IE 9]>
     <script src="../../../../../_static/js/html5shiv.min.js"></script>
     <script src="../../../../../_static/js/html5shiv.min.js"></script>
   <![endif]-->
   <![endif]-->
@@ -14,7 +15,9 @@
         <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
         <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
         <script src="../../../../../_static/jquery.js"></script>
         <script src="../../../../../_static/jquery.js"></script>
         <script src="../../../../../_static/underscore.js"></script>
         <script src="../../../../../_static/underscore.js"></script>
+        <script src="../../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
         <script src="../../../../../_static/doctools.js"></script>
         <script src="../../../../../_static/doctools.js"></script>
+        <script src="../../../../../_static/sphinx_highlight.js"></script>
     <script src="../../../../../_static/js/theme.js"></script>
     <script src="../../../../../_static/js/theme.js"></script>
     <link rel="index" title="Index" href="../../../../../genindex.html" />
     <link rel="index" title="Index" href="../../../../../genindex.html" />
     <link rel="search" title="Search" href="../../../../../search.html" /> 
     <link rel="search" title="Search" href="../../../../../search.html" /> 
@@ -35,30 +38,29 @@
   </form>
   </form>
 </div>
 </div>
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
-              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
+              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
 <ul>
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#supergradients">SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#quick-installation">Quick Installation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#what-s-new">What’s New</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#coming-soon">Coming soon</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#table-of-content">Table of Content</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#getting-started">Getting Started</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#advanced-features">Advanced Features</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#installation-methods">Installation Methods</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#documentation">Documentation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#contributing">Contributing</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#citation">Citation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#community">Community</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#license">License</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#deci-platform">Deci Platform</a></li>
 </ul>
 </ul>
-<p class="caption"><span class="caption-text">Technical Documentation</span></p>
+<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
 <ul>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
-</ul>
-<p class="caption"><span class="caption-text">User Guide</span></p>
-<ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html">What is SuperGradients?</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#installation">Installation</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#training-parameters">Training Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#network-architectures">Network Architectures</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
 </ul>
 </ul>
 
 
         </div>
         </div>
@@ -92,7 +94,7 @@
 <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.segmentation_datasets.segmentation_dataset</span> <span class="kn">import</span> <span class="n">SegmentationDataSet</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.segmentation_datasets.segmentation_dataset</span> <span class="kn">import</span> <span class="n">SegmentationDataSet</span>
 
 
 
 
-<div class="viewcode-block" id="SuperviselyPersonsDataset"><a class="viewcode-back" href="../../../../../super_gradients.training.datasets.segmentation_datasets.html#super_gradients.training.datasets.segmentation_datasets.SuperviselyPersonsDataset">[docs]</a><span class="k">class</span> <span class="nc">SuperviselyPersonsDataset</span><span class="p">(</span><span class="n">SegmentationDataSet</span><span class="p">):</span>
+<div class="viewcode-block" id="SuperviselyPersonsDataset"><a class="viewcode-back" href="../../../../../super_gradients.training.html#super_gradients.training.datasets.SuperviselyPersonsDataset">[docs]</a><span class="k">class</span> <span class="nc">SuperviselyPersonsDataset</span><span class="p">(</span><span class="n">SegmentationDataSet</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    SuperviselyPersonsDataset - Segmentation Data Set Class for Supervisely Persons Segmentation Data Set,</span>
 <span class="sd">    SuperviselyPersonsDataset - Segmentation Data Set Class for Supervisely Persons Segmentation Data Set,</span>
 <span class="sd">    main resolution of dataset: (600 x 800).</span>
 <span class="sd">    main resolution of dataset: (600 x 800).</span>
@@ -126,7 +128,8 @@
                     <span class="bp">self</span><span class="o">.</span><span class="n">samples_targets_tuples_list</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">sample_path</span><span class="p">,</span> <span class="n">target_path</span><span class="p">))</span>
                     <span class="bp">self</span><span class="o">.</span><span class="n">samples_targets_tuples_list</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">sample_path</span><span class="p">,</span> <span class="n">target_path</span><span class="p">))</span>
                 <span class="k">else</span><span class="p">:</span>
                 <span class="k">else</span><span class="p">:</span>
                     <span class="k">raise</span> <span class="ne">AssertionError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Sample and/or target file(s) not found or in illegal format &quot;</span>
                     <span class="k">raise</span> <span class="ne">AssertionError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Sample and/or target file(s) not found or in illegal format &quot;</span>
-                                         <span class="sa">f</span><span class="s2">&quot;(sample path: </span><span class="si">{</span><span class="n">sample_path</span><span class="si">}</span><span class="s2">, target path: </span><span class="si">{</span><span class="n">target_path</span><span class="si">}</span><span class="s2">)&quot;</span><span class="p">)</span></div>
+                                         <span class="sa">f</span><span class="s2">&quot;(sample path: </span><span class="si">{</span><span class="n">sample_path</span><span class="si">}</span><span class="s2">, target path: </span><span class="si">{</span><span class="n">target_path</span><span class="si">}</span><span class="s2">)&quot;</span><span class="p">)</span>
+        <span class="nb">super</span><span class="p">(</span><span class="n">SuperviselyPersonsDataset</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">_generate_samples_and_targets</span><span class="p">()</span></div>
 </pre></div>
 </pre></div>
 
 
            </div>
            </div>
@@ -156,4 +159,4 @@
   </script> 
   </script> 
 
 
 </body>
 </body>
-</html>
+</html>
Discard
Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
@@ -3,10 +3,11 @@
 <head>
 <head>
   <meta charset="utf-8" />
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-  <title>super_gradients.training.datasets.sg_dataset &mdash; SuperGradients 1.0 documentation</title>
+  <title>super_gradients.training.datasets.sg_dataset &mdash; SuperGradients 3.0.3 documentation</title>
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
+      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
   <!--[if lt IE 9]>
   <!--[if lt IE 9]>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
   <![endif]-->
   <![endif]-->
@@ -14,7 +15,9 @@
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/underscore.js"></script>
         <script src="../../../../_static/underscore.js"></script>
+        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
         <script src="../../../../_static/doctools.js"></script>
         <script src="../../../../_static/doctools.js"></script>
+        <script src="../../../../_static/sphinx_highlight.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="search" title="Search" href="../../../../search.html" /> 
     <link rel="search" title="Search" href="../../../../search.html" /> 
@@ -35,30 +38,29 @@
   </form>
   </form>
 </div>
 </div>
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
-              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
+              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
 <ul>
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
 </ul>
 </ul>
-<p class="caption"><span class="caption-text">Technical Documentation</span></p>
+<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
 <ul>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
-</ul>
-<p class="caption"><span class="caption-text">User Guide</span></p>
-<ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
 </ul>
 </ul>
 
 
         </div>
         </div>
@@ -97,7 +99,7 @@
 <span class="n">IMG_EXTENSIONS</span> <span class="o">=</span> <span class="p">(</span><span class="s1">&#39;.jpg&#39;</span><span class="p">,</span> <span class="s1">&#39;.jpeg&#39;</span><span class="p">,</span> <span class="s1">&#39;.png&#39;</span><span class="p">,</span> <span class="s1">&#39;.ppm&#39;</span><span class="p">,</span> <span class="s1">&#39;.bmp&#39;</span><span class="p">,</span> <span class="s1">&#39;.pgm&#39;</span><span class="p">,</span> <span class="s1">&#39;.tif&#39;</
 <span class="n">IMG_EXTENSIONS</span> <span class="o">=</span> <span class="p">(</span><span class="s1">&#39;.jpg&#39;</span><span class="p">,</span> <span class="s1">&#39;.jpeg&#39;</span><span class="p">,</span> <span class="s1">&#39;.png&#39;</span><span class="p">,</span> <span class="s1">&#39;.ppm&#39;</span><span class="p">,</span> <span class="s1">&#39;.bmp&#39;</span><span class="p">,</span> <span class="s1">&#39;.pgm&#39;</span><span class="p">,</span> <span class="s1">&#39;.tif&#39;</
 
 
 
 
-<div class="viewcode-block" id="BaseSgVisionDataset"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.BaseSgVisionDataset">[docs]</a><span class="k">class</span> <span class="nc">BaseSgVisionDataset</span><span class="p">(</span><span class="n">VisionDataset</span><span class="p">):</span>
+<span class="k">class</span> <span class="nc">BaseSgVisionDataset</span><span class="p">(</span><span class="n">VisionDataset</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    BaseSgVisionDataset</span>
 <span class="sd">    BaseSgVisionDataset</span>
 <span class="sd">    &quot;&quot;&quot;</span>
 <span class="sd">    &quot;&quot;&quot;</span>
@@ -160,16 +162,16 @@
 
 
         <span class="k">return</span> <span class="kc">False</span>
         <span class="k">return</span> <span class="kc">False</span>
 
 
-<div class="viewcode-block" id="BaseSgVisionDataset.numpy_loader_func"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.BaseSgVisionDataset.numpy_loader_func">[docs]</a>    <span class="nd">@staticmethod</span>
+    <span class="nd">@staticmethod</span>
     <span class="k">def</span> <span class="nf">numpy_loader_func</span><span class="p">(</span><span class="n">path</span><span class="p">):</span>
     <span class="k">def</span> <span class="nf">numpy_loader_func</span><span class="p">(</span><span class="n">path</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">        _numpy_loader_func - Uses numpy load func</span>
 <span class="sd">        _numpy_loader_func - Uses numpy load func</span>
 <span class="sd">            :param path:</span>
 <span class="sd">            :param path:</span>
 <span class="sd">            :return:</span>
 <span class="sd">            :return:</span>
 <span class="sd">        &quot;&quot;&quot;</span>
 <span class="sd">        &quot;&quot;&quot;</span>
-        <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">path</span><span class="p">)</span></div>
+        <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">path</span><span class="p">)</span>
 
 
-<div class="viewcode-block" id="BaseSgVisionDataset.text_file_loader_func"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.BaseSgVisionDataset.text_file_loader_func">[docs]</a>    <span class="nd">@staticmethod</span>
+    <span class="nd">@staticmethod</span>
     <span class="k">def</span> <span class="nf">text_file_loader_func</span><span class="p">(</span><span class="n">text_file_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">inline_splitter</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39; &#39;</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">list</span><span class="p">:</span>
     <span class="k">def</span> <span class="nf">text_file_loader_func</span><span class="p">(</span><span class="n">text_file_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">inline_splitter</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39; &#39;</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">list</span><span class="p">:</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">        text_file_loader_func - Uses a line by line based code to get vectorized data from a text-based file</span>
 <span class="sd">        text_file_loader_func - Uses a line by line based code to get vectorized data from a text-based file</span>
@@ -184,10 +186,10 @@
         <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">text_file_path</span><span class="p">,</span> <span class="s2">&quot;r&quot;</span><span class="p">,</span> <span class="n">encoding</span><span class="o">=</span><span class="s2">&quot;utf-8&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">text_file</span><span class="p">:</span>
         <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">text_file_path</span><span class="p">,</span> <span class="s2">&quot;r&quot;</span><span class="p">,</span> <span class="n">encoding</span><span class="o">=</span><span class="s2">&quot;utf-8&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">text_file</span><span class="p">:</span>
             <span class="n">targets_list</span> <span class="o">=</span> <span class="p">[</span><span class="nb">tuple</span><span class="p">(</span><span class="nb">map</span><span class="p">(</span><span class="nb">float</span><span class="p">,</span> <span class="n">line</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">inline_splitter</span><span class="p">)))</span> <span class="k">for</span> <span class="n">line</span> <span class="ow">in</
             <span class="n">targets_list</span> <span class="o">=</span> <span class="p">[</span><span class="nb">tuple</span><span class="p">(</span><span class="nb">map</span><span class="p">(</span><span class="nb">float</span><span class="p">,</span> <span class="n">line</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">inline_splitter</span><span class="p">)))</span> <span class="k">for</span> <span class="n">line</span> <span class="ow">in</
 
 
-        <span class="k">return</span> <span class="n">targets_list</span></div></div>
+        <span class="k">return</span> <span class="n">targets_list</span>
 
 
 
 
-<div class="viewcode-block" id="DirectoryDataSet"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.DirectoryDataSet">[docs]</a><span class="k">class</span> <span class="nc">DirectoryDataSet</span><span class="p">(</span><span class="n">BaseSgVisionDataset</span><span class="p">):</span>
+<div class="viewcode-block" id="DirectoryDataSet"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.datasets.DirectoryDataSet">[docs]</a><span class="k">class</span> <span class="nc">DirectoryDataSet</span><span class="p">(</span><span class="n">BaseSgVisionDataset</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    DirectoryDataSet - A PyTorch Vision Data Set extension that receives a root Dir and two separate sub directories:</span>
 <span class="sd">    DirectoryDataSet - A PyTorch Vision Data Set extension that receives a root Dir and two separate sub directories:</span>
 <span class="sd">                        - Sub-Directory for Samples</span>
 <span class="sd">                        - Sub-Directory for Samples</span>
@@ -279,7 +281,7 @@
                 <span class="nb">print</span><span class="p">(</span><span class="vm">__name__</span> <span class="o">+</span> <span class="s1">&#39; There are &#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">missing_files_counter</span><span class="p">)</span> <span class="o">+</span> <span class="s1">&#39; missing  &#39;</span> <span class="o">+</span> <span class="n">counter_name</span><span class="p">)</span></div>
                 <span class="nb">print</span><span class="p">(</span><span class="vm">__name__</span> <span class="o">+</span> <span class="s1">&#39; There are &#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">missing_files_counter</span><span class="p">)</span> <span class="o">+</span> <span class="s1">&#39; missing  &#39;</span> <span class="o">+</span> <span class="n">counter_name</span><span class="p">)</span></div>
 
 
 
 
-<div class="viewcode-block" id="ListDataset"><a class="viewcode-back" href="../../../../super_gradients.training.datasets.html#super_gradients.training.datasets.ListDataset">[docs]</a><span class="k">class</span> <span class="nc">ListDataset</span><span class="p">(</span><span class="n">BaseSgVisionDataset</span><span class="p">):</span>
+<div class="viewcode-block" id="ListDataset"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.datasets.ListDataset">[docs]</a><span class="k">class</span> <span class="nc">ListDataset</span><span class="p">(</span><span class="n">BaseSgVisionDataset</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    ListDataset - A PyTorch Vision Data Set extension that receives a file with FULL PATH to each of the samples.</span>
 <span class="sd">    ListDataset - A PyTorch Vision Data Set extension that receives a file with FULL PATH to each of the samples.</span>
 <span class="sd">                  Then, the assumption is that for every sample, there is a * matching target * in the same</span>
 <span class="sd">                  Then, the assumption is that for every sample, there is a * matching target * in the same</span>
@@ -383,4 +385,4 @@
   </script> 
   </script> 
 
 
 </body>
 </body>
-</html>
+</html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.training.exceptions.dataset_exceptions &mdash; SuperGradients 1.0 documentation</title>
  7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
  10. <!--[if lt IE 9]>
  11. <script src="../../../../_static/js/html5shiv.min.js"></script>
  12. <![endif]-->
  13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
  14. <script src="../../../../_static/jquery.js"></script>
  15. <script src="../../../../_static/underscore.js"></script>
  16. <script src="../../../../_static/doctools.js"></script>
  17. <script src="../../../../_static/js/theme.js"></script>
  18. <link rel="index" title="Index" href="../../../../genindex.html" />
  19. <link rel="search" title="Search" href="../../../../search.html" />
  20. </head>
  21. <body class="wy-body-for-nav">
  22. <div class="wy-grid-for-nav">
  23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  24. <div class="wy-side-scroll">
  25. <div class="wy-side-nav-search" >
  26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
  27. </a>
  28. <div role="search">
  29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
  30. <input type="text" name="q" placeholder="Search docs" />
  31. <input type="hidden" name="check_keywords" value="yes" />
  32. <input type="hidden" name="area" value="default" />
  33. </form>
  34. </div>
  35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
  37. <ul>
  38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
  39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
  40. </ul>
  41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
  42. <ul>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
  45. </ul>
  46. <p class="caption"><span class="caption-text">User Guide</span></p>
  47. <ul>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
  57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
  58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
  59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
  60. </ul>
  61. </div>
  62. </div>
  63. </nav>
  64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  66. <a href="../../../../index.html">SuperGradients</a>
  67. </nav>
  68. <div class="wy-nav-content">
  69. <div class="rst-content">
  70. <div role="navigation" aria-label="Page navigation">
  71. <ul class="wy-breadcrumbs">
  72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
  74. <li>super_gradients.training.exceptions.dataset_exceptions</li>
  75. <li class="wy-breadcrumbs-aside">
  76. </li>
  77. </ul>
  78. <hr/>
  79. </div>
  80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  81. <div itemprop="articleBody">
  82. <h1>Source code for super_gradients.training.exceptions.dataset_exceptions</h1><div class="highlight"><pre>
  83. <span></span>
  84. <div class="viewcode-block" id="IllegalDatasetParameterException"><a class="viewcode-back" href="../../../../super_gradients.training.exceptions.html#super_gradients.training.exceptions.dataset_exceptions.IllegalDatasetParameterException">[docs]</a><span class="k">class</span> <span class="nc">IllegalDatasetParameterException</span><span class="p">(</span><span class="ne">Exception</span><span class="p">):</span>
  85. <span class="sd">&quot;&quot;&quot;</span>
  86. <span class="sd"> Exception raised illegal dataset param.</span>
  87. <span class="sd"> Attributes:</span>
  88. <span class="sd"> message -- explanation of the error</span>
  89. <span class="sd"> &quot;&quot;&quot;</span>
  90. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">desc</span><span class="p">):</span>
  91. <span class="bp">self</span><span class="o">.</span><span class="n">message</span> <span class="o">=</span> <span class="s2">&quot;Unsupported dataset parameter format: &quot;</span> <span class="o">+</span> <span class="n">desc</span>
  92. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">message</span><span class="p">)</span></div>
  93. <div class="viewcode-block" id="EmptyDatasetException"><a class="viewcode-back" href="../../../../super_gradients.training.exceptions.html#super_gradients.training.exceptions.dataset_exceptions.EmptyDatasetException">[docs]</a><span class="k">class</span> <span class="nc">EmptyDatasetException</span><span class="p">(</span><span class="ne">Exception</span><span class="p">):</span>
  94. <span class="sd">&quot;&quot;&quot;</span>
  95. <span class="sd"> Exception raised when a dataset does not have any image for a specific config</span>
  96. <span class="sd"> Attributes:</span>
  97. <span class="sd"> message -- explanation of the error</span>
  98. <span class="sd"> &quot;&quot;&quot;</span>
  99. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">desc</span><span class="p">):</span>
  100. <span class="bp">self</span><span class="o">.</span><span class="n">message</span> <span class="o">=</span> <span class="s2">&quot;Empty Dataset: &quot;</span> <span class="o">+</span> <span class="n">desc</span>
  101. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">message</span><span class="p">)</span></div>
  102. <div class="viewcode-block" id="UnsupportedBatchItemsFormat"><a class="viewcode-back" href="../../../../super_gradients.training.exceptions.html#super_gradients.training.exceptions.dataset_exceptions.UnsupportedBatchItemsFormat">[docs]</a><span class="k">class</span> <span class="nc">UnsupportedBatchItemsFormat</span><span class="p">(</span><span class="ne">ValueError</span><span class="p">):</span>
  103. <span class="sd">&quot;&quot;&quot;Exception raised illegal batch items returned from data loader.</span>
  104. <span class="sd"> Attributes:</span>
  105. <span class="sd"> message -- explanation of the error</span>
  106. <span class="sd"> &quot;&quot;&quot;</span>
  107. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  108. <span class="bp">self</span><span class="o">.</span><span class="n">message</span> <span class="o">=</span> <span class="s2">&quot;Batch items returned by the data loader expected format: </span><span class="se">\n</span><span class="s2">&quot;</span> \
  109. <span class="s2">&quot;1. torch.Tensor or tuple, s.t inputs = batch_items[0], targets = batch_items[1] and len(&quot;</span> \
  110. <span class="s2">&quot;batch_items) = 2 </span><span class="se">\n</span><span class="s2">&quot;</span> \
  111. <span class="s2">&quot;2. tuple: (inputs, targets, additional_batch_items)&quot;</span>
  112. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">message</span><span class="p">)</span></div>
  113. </pre></div>
  114. </div>
  115. </div>
  116. <footer>
  117. <hr/>
  118. <div role="contentinfo">
  119. <p>&#169; Copyright 2021, SuperGradients team.</p>
  120. </div>
  121. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  122. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  123. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  124. </footer>
  125. </div>
  126. </div>
  127. </section>
  128. </div>
  129. <script>
  130. jQuery(function () {
  131. SphinxRtdTheme.Navigation.enable(true);
  132. });
  133. </script>
  134. </body>
  135. </html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.training.exceptions.sg_model_exceptions &mdash; SuperGradients 1.0 documentation</title>
  7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
  10. <!--[if lt IE 9]>
  11. <script src="../../../../_static/js/html5shiv.min.js"></script>
  12. <![endif]-->
  13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
  14. <script src="../../../../_static/jquery.js"></script>
  15. <script src="../../../../_static/underscore.js"></script>
  16. <script src="../../../../_static/doctools.js"></script>
  17. <script src="../../../../_static/js/theme.js"></script>
  18. <link rel="index" title="Index" href="../../../../genindex.html" />
  19. <link rel="search" title="Search" href="../../../../search.html" />
  20. </head>
  21. <body class="wy-body-for-nav">
  22. <div class="wy-grid-for-nav">
  23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  24. <div class="wy-side-scroll">
  25. <div class="wy-side-nav-search" >
  26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
  27. </a>
  28. <div role="search">
  29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
  30. <input type="text" name="q" placeholder="Search docs" />
  31. <input type="hidden" name="check_keywords" value="yes" />
  32. <input type="hidden" name="area" value="default" />
  33. </form>
  34. </div>
  35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
  37. <ul>
  38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
  39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
  40. </ul>
  41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
  42. <ul>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
  45. </ul>
  46. <p class="caption"><span class="caption-text">User Guide</span></p>
  47. <ul>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
  57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
  58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
  59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
  60. </ul>
  61. </div>
  62. </div>
  63. </nav>
  64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  66. <a href="../../../../index.html">SuperGradients</a>
  67. </nav>
  68. <div class="wy-nav-content">
  69. <div class="rst-content">
  70. <div role="navigation" aria-label="Page navigation">
  71. <ul class="wy-breadcrumbs">
  72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
  74. <li>super_gradients.training.exceptions.sg_model_exceptions</li>
  75. <li class="wy-breadcrumbs-aside">
  76. </li>
  77. </ul>
  78. <hr/>
  79. </div>
  80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  81. <div itemprop="articleBody">
  82. <h1>Source code for super_gradients.training.exceptions.sg_model_exceptions</h1><div class="highlight"><pre>
  83. <span></span>
  84. <div class="viewcode-block" id="UnsupportedTrainingParameterFormat"><a class="viewcode-back" href="../../../../super_gradients.training.exceptions.html#super_gradients.training.exceptions.sg_model_exceptions.UnsupportedTrainingParameterFormat">[docs]</a><span class="k">class</span> <span class="nc">UnsupportedTrainingParameterFormat</span><span class="p">(</span><span class="ne">Exception</span><span class="p">):</span>
  85. <span class="sd">&quot;&quot;&quot;Exception raised illegal training param format.</span>
  86. <span class="sd"> Attributes:</span>
  87. <span class="sd"> message -- explanation of the error</span>
  88. <span class="sd"> &quot;&quot;&quot;</span>
  89. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">desc</span><span class="p">):</span>
  90. <span class="bp">self</span><span class="o">.</span><span class="n">message</span> <span class="o">=</span> <span class="s2">&quot;Unsupported training parameter format: &quot;</span> <span class="o">+</span> <span class="n">desc</span>
  91. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">message</span><span class="p">)</span></div>
  92. <div class="viewcode-block" id="UnsupportedOptimizerFormat"><a class="viewcode-back" href="../../../../super_gradients.training.exceptions.html#super_gradients.training.exceptions.sg_model_exceptions.UnsupportedOptimizerFormat">[docs]</a><span class="k">class</span> <span class="nc">UnsupportedOptimizerFormat</span><span class="p">(</span><span class="n">UnsupportedTrainingParameterFormat</span><span class="p">):</span>
  93. <span class="sd">&quot;&quot;&quot;Exception raised illegal optimizer format.</span>
  94. <span class="sd"> Attributes:</span>
  95. <span class="sd"> message -- explanation of the error</span>
  96. <span class="sd"> &quot;&quot;&quot;</span>
  97. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  98. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
  99. <span class="s2">&quot;optimizer parameter expected one of [&#39;Adam&#39;,&#39;SGD&#39;,&#39;RMSProp&#39;], or torch.optim.Optimizer object&quot;</span><span class="p">)</span></div>
  100. <div class="viewcode-block" id="IllegalDataloaderInitialization"><a class="viewcode-back" href="../../../../super_gradients.training.exceptions.html#super_gradients.training.exceptions.sg_model_exceptions.IllegalDataloaderInitialization">[docs]</a><span class="k">class</span> <span class="nc">IllegalDataloaderInitialization</span><span class="p">(</span><span class="ne">Exception</span><span class="p">):</span>
  101. <span class="sd">&quot;&quot;&quot;Exception raised illegal data loaders.</span>
  102. <span class="sd"> &quot;&quot;&quot;</span>
  103. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  104. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
  105. <span class="s2">&quot;train_loader, valid_loader and class parameters are required when initializing SgModel with data loaders&quot;</span><span class="p">)</span></div>
  106. </pre></div>
  107. </div>
  108. </div>
  109. <footer>
  110. <hr/>
  111. <div role="contentinfo">
  112. <p>&#169; Copyright 2021, SuperGradients team.</p>
  113. </div>
  114. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  115. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  116. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  117. </footer>
  118. </div>
  119. </div>
  120. </section>
  121. </div>
  122. <script>
  123. jQuery(function () {
  124. SphinxRtdTheme.Navigation.enable(true);
  125. });
  126. </script>
  127. </body>
  128. </html>
Discard
Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
@@ -3,10 +3,11 @@
 <head>
 <head>
   <meta charset="utf-8" />
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-  <title>super_gradients.training.kd_model.kd_model &mdash; SuperGradients 1.0 documentation</title>
+  <title>super_gradients.training.kd_trainer.kd_trainer &mdash; SuperGradients 3.0.3 documentation</title>
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
+      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
   <!--[if lt IE 9]>
   <!--[if lt IE 9]>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
   <![endif]-->
   <![endif]-->
@@ -14,7 +15,9 @@
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/underscore.js"></script>
         <script src="../../../../_static/underscore.js"></script>
+        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
         <script src="../../../../_static/doctools.js"></script>
         <script src="../../../../_static/doctools.js"></script>
+        <script src="../../../../_static/sphinx_highlight.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="search" title="Search" href="../../../../search.html" /> 
     <link rel="search" title="Search" href="../../../../search.html" /> 
@@ -35,30 +38,29 @@
   </form>
   </form>
 </div>
 </div>
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
-              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
+              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
 <ul>
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
 </ul>
 </ul>
-<p class="caption"><span class="caption-text">Technical Documentation</span></p>
+<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
 <ul>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
-</ul>
-<p class="caption"><span class="caption-text">User Guide</span></p>
-<ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
 </ul>
 </ul>
 
 
         </div>
         </div>
@@ -76,7 +78,7 @@
   <ul class="wy-breadcrumbs">
   <ul class="wy-breadcrumbs">
       <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
       <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
           <li><a href="../../../index.html">Module code</a> &raquo;</li>
           <li><a href="../../../index.html">Module code</a> &raquo;</li>
-      <li>super_gradients.training.kd_model.kd_model</li>
+      <li>super_gradients.training.kd_trainer.kd_trainer</li>
       <li class="wy-breadcrumbs-aside">
       <li class="wy-breadcrumbs-aside">
       </li>
       </li>
   </ul>
   </ul>
@@ -85,101 +87,85 @@
           <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
           <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
            <div itemprop="articleBody">
            <div itemprop="articleBody">
              
              
-  <h1>Source code for super_gradients.training.kd_model.kd_model</h1><div class="highlight"><pre>
-<span></span><span class="kn">import</span> <span class="nn">torch.nn</span>
-
+  <h1>Source code for super_gradients.training.kd_trainer.kd_trainer</h1><div class="highlight"><pre>
+<span></span><span class="kn">import</span> <span class="nn">hydra</span>
+<span class="kn">import</span> <span class="nn">torch.nn</span>
+<span class="kn">from</span> <span class="nn">omegaconf</span> <span class="kn">import</span> <span class="n">DictConfig</span>
+<span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">DataLoader</span>
+
+<span class="kn">from</span> <span class="nn">super_gradients.common</span> <span class="kn">import</span> <span class="n">MultiGPUMode</span>
+<span class="kn">from</span> <span class="nn">super_gradients.training.dataloaders</span> <span class="kn">import</span> <span class="n">dataloaders</span>
+<span class="kn">from</span> <span class="nn">super_gradients.training.models</span> <span class="kn">import</span> <span class="n">SgModule</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.models.all_architectures</span> <span class="kn">import</span> <span class="n">KD_ARCHITECTURES</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.models.all_architectures</span> <span class="kn">import</span> <span class="n">KD_ARCHITECTURES</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.models.kd_modules.kd_module</span> <span class="kn">import</span> <span class="n">KDModule</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.models.kd_modules.kd_module</span> <span class="kn">import</span> <span class="n">KDModule</span>
-<span class="kn">from</span> <span class="nn">super_gradients.training.sg_model</span> <span class="kn">import</span> <span class="n">SgModel</span>
+<span class="kn">from</span> <span class="nn">super_gradients.training.sg_trainer</span> <span class="kn">import</span> <span class="n">Trainer</span>
 <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Union</span>
 <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Union</span>
 <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
 <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
-<span class="kn">from</span> <span class="nn">super_gradients.training</span> <span class="kn">import</span> <span class="n">utils</span> <span class="k">as</span> <span class="n">core_utils</span>
+<span class="kn">from</span> <span class="nn">super_gradients.training</span> <span class="kn">import</span> <span class="n">utils</span> <span class="k">as</span> <span class="n">core_utils</span><span class="p">,</span> <span class="n">models</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.pretrained_models</span> <span class="kn">import</span> <span class="n">PRETRAINED_NUM_CLASSES</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.pretrained_models</span> <span class="kn">import</span> <span class="n">PRETRAINED_NUM_CLASSES</span>
-<span class="kn">from</span> <span class="nn">super_gradients.training.utils</span> <span class="kn">import</span> <span class="n">get_param</span>
+<span class="kn">from</span> <span class="nn">super_gradients.training.utils</span> <span class="kn">import</span> <span class="n">get_param</span><span class="p">,</span> <span class="n">HpmStruct</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.utils.checkpoint_utils</span> <span class="kn">import</span> <span class="n">read_ckpt_state_dict</span><span class="p">,</span> \
 <span class="kn">from</span> <span class="nn">super_gradients.training.utils.checkpoint_utils</span> <span class="kn">import</span> <span class="n">read_ckpt_state_dict</span><span class="p">,</span> \
     <span class="n">load_checkpoint_to_model</span>
     <span class="n">load_checkpoint_to_model</span>
-<span class="kn">from</span> <span class="nn">super_gradients.training.exceptions.kd_model_exceptions</span> <span class="kn">import</span> <span class="n">ArchitectureKwargsException</span><span class="p">,</span> \
+<span class="kn">from</span> <span class="nn">super_gradients.training.exceptions.kd_trainer_exceptions</span> <span class="kn">import</span> <span class="n">ArchitectureKwargsException</span><span class="p">,</span> \
     <span class="n">UnsupportedKDArchitectureException</span><span class="p">,</span> <span class="n">InconsistentParamsException</span><span class="p">,</span> <span class="n">UnsupportedKDModelArgException</span><span class="p">,</span> \
     <span class="n">UnsupportedKDArchitectureException</span><span class="p">,</span> <span class="n">InconsistentParamsException</span><span class="p">,</span> <span class="n">UnsupportedKDModelArgException</span><span class="p">,</span> \
     <span class="n">TeacherKnowledgeException</span><span class="p">,</span> <span class="n">UndefinedNumClassesException</span>
     <span class="n">TeacherKnowledgeException</span><span class="p">,</span> <span class="n">UndefinedNumClassesException</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.utils.callbacks</span> <span class="kn">import</span> <span class="n">KDModelMetricsUpdateCallback</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.utils.callbacks</span> <span class="kn">import</span> <span class="n">KDModelMetricsUpdateCallback</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.utils.ema</span> <span class="kn">import</span> <span class="n">KDModelEMA</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.utils.ema</span> <span class="kn">import</span> <span class="n">KDModelEMA</span>
+<span class="kn">from</span> <span class="nn">super_gradients.training.utils.sg_trainer_utils</span> <span class="kn">import</span> <span class="n">parse_args</span>
+
 <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
 <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
 
 
 
 
-<div class="viewcode-block" id="KDModel"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.KDModel">[docs]</a><span class="k">class</span> <span class="nc">KDModel</span><span class="p">(</span><span class="n">SgModel</span><span class="p">):</span>
-    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
-        <span class="nb">super</span><span class="p">(</span><span class="n">KDModel</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
+<div class="viewcode-block" id="KDTrainer"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.KDTrainer">[docs]</a><span class="k">class</span> <span class="nc">KDTrainer</span><span class="p">(</span><span class="n">Trainer</span><span class="p">):</span>
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">experiment_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">device</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">multi_gpu</span><span class="p">:</span> <span class="n">Union</span
+                 <span class="n">ckpt_root_dir</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">experiment_name</span><span class="o">=</span><span class="n">experiment_name</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">multi_gpu</span><span class="o">=</span><span class="n">multi_gpu</span><span class="p">,</span> <span class=
         <span class="bp">self</span><span class="o">.</span><span class="n">student_architecture</span> <span class="o">=</span> <span class="kc">None</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">student_architecture</span> <span class="o">=</span> <span class="kc">None</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">teacher_architecture</span> <span class="o">=</span> <span class="kc">None</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">teacher_architecture</span> <span class="o">=</span> <span class="kc">None</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">student_arch_params</span> <span class="o">=</span> <span class="kc">None</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">student_arch_params</span> <span class="o">=</span> <span class="kc">None</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">teacher_arch_params</span> <span class="o">=</span> <span class="kc">None</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">teacher_arch_params</span> <span class="o">=</span> <span class="kc">None</span>
 
 
-<div class="viewcode-block" id="KDModel.build_model"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.KDModel.build_model">[docs]</a>    <span class="k">def</span> <span class="nf">build_model</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
-                    <span class="c1"># noqa: C901 - too complex</span>
-                    <span class="n">architecture</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">KDModule</span><span class="p">]</span> <span class="o">=</span> <span class="s1">&#39;kd_module&#39;</span><span class="p">,</span>
-                    <span class="n">arch_params</span><span class="o">=</span><span class="p">{},</span> <span class="n">checkpoint_params</span><span class="o">=</span><span class="p">{},</span>
-                    <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
+<div class="viewcode-block" id="KDTrainer.train_from_config"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.KDTrainer.train_from_config">[docs]</a>    <span class="nd">@classmethod</span>
+    <span class="k">def</span> <span class="nf">train_from_config</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">cfg</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">DictConfig</span><span class="p">,</span> <span class="nb">dict</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
-<span class="sd">        :param architecture: (Union[str, KDModule]) Defines the network&#39;s architecture from models/KD_ARCHITECTURES</span>
-<span class="sd">         (default=&#39;kd_module&#39;)</span>
-
-<span class="sd">        :param arch_params: (dict) Architecture H.P. e.g.: block, num_blocks, num_classes, etc to be passed to kd</span>
-<span class="sd">            architecture class (discarded when architecture is KDModule instance)</span>
-
-<span class="sd">        :param checkpoint_params: (dict) A dictionary like object with the following keys/values:</span>
-
-<span class="sd">              student_pretrained_weights:   String describing the dataset of the pretrained weights (for example</span>
-<span class="sd">              &quot;imagenent&quot;) for the student network.</span>
-
-<span class="sd">              teacher_pretrained_weights:   String describing the dataset of the pretrained weights (for example</span>
-<span class="sd">              &quot;imagenent&quot;) for the teacher network.</span>
-
-<span class="sd">              teacher_checkpoint_path:    Local path to the teacher&#39;s checkpoint. Note that when passing pretrained_weights</span>
-<span class="sd">                                   through teacher_arch_params these weights will be overridden by the</span>
-<span class="sd">                                   pretrained checkpoint. (default=None)</span>
-
-<span class="sd">              load_kd_model_checkpoint:   Whether to load an entire KDModule checkpoint (used to continue KD training)</span>
-<span class="sd">               (default=False)</span>
-
-<span class="sd">              kd_model_source_ckpt_folder_name: Folder name to load an entire KDModule checkpoint from</span>
-<span class="sd">                (self.experiment_name if none is given) to resume KD training (default=None)</span>
+<span class="sd">        Trains according to cfg recipe configuration.</span>
 
 
-<span class="sd">              kd_model_external_checkpoint_path: The path to the external checkpoint to be loaded. Can be absolute or relative</span>
-<span class="sd">                                               (ie: path/to/checkpoint.pth). If provided, will automatically attempt to</span>
-<span class="sd">                                               load the checkpoint even if the load_checkpoint flag is not provided.</span>
-<span class="sd">                                               (deafult=None)</span>
-
-<span class="sd">        :keyword student_architecture: (Union[str, SgModule]) Defines the student&#39;s architecture from</span>
-<span class="sd">            models/ALL_ARCHITECTURES (when str), or directly defined the student network (when SgModule).</span>
-
-<span class="sd">        :keyword teacher_architecture: (Union[str, SgModule]) Defines the teacher&#39;s architecture from</span>
-<span class="sd">            models/ALL_ARCHITECTURES (when str), or directly defined the teacher network (when SgModule).</span>
-
-<span class="sd">        :keyword student_arch_params: (dict) Architecture H.P. e.g.: block, num_blocks, num_classes, etc for student</span>
-<span class="sd">            net. (deafult={})</span>
+<span class="sd">        @param cfg: The parsed DictConfig from yaml recipe files</span>
+<span class="sd">        @return: output of kd_trainer.train(...) (i.e results tuple)</span>
+<span class="sd">        &quot;&quot;&quot;</span>
+        <span class="c1"># INSTANTIATE ALL OBJECTS IN CFG</span>
+        <span class="n">cfg</span> <span class="o">=</span> <span class="n">hydra</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">instantiate</span><span class="p">(</span><span class="n">cfg</span><span class="p">)</span>
 
 
-<span class="sd">        :keyword teacher_arch_params: (dict) Architecture H.P. e.g.: block, num_blocks, num_classes, etc for teacher</span>
-<span class="sd">            net. (deafult={})</span>
+        <span class="n">kwargs</span> <span class="o">=</span> <span class="n">parse_args</span><span class="p">(</span><span class="n">cfg</span><span class="p">,</span> <span class="bp">cls</span><span class="o">.</span><span class="fm">__init__</span><span class="p">)</span>
 
 
-<span class="sd">        :keyword run_teacher_on_eval: (bool)- whether to run self.teacher at eval mode regardless of self.train(mode)</span>
+        <span class="n">trainer</span> <span class="o">=</span> <span class="n">KDTrainer</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
 
 
+        <span class="c1"># INSTANTIATE DATA LOADERS</span>
+        <span class="n">train_dataloader</span> <span class="o">=</span> <span class="n">dataloaders</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">train_dataloader</span><span class="p">,</span>
+                                           <span class="n">dataset_params</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">train_dataset_params</span><span class="p">,</span>
+                                           <span class="n">dataloader_params</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">train_dataloader_params</span><span class="p">)</span>
 
 
-<span class="sd">        &quot;&quot;&quot;</span>
-        <span class="n">kwargs</span><span class="o">.</span><span class="n">setdefault</span><span class="p">(</span><span class="s2">&quot;student_architecture&quot;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
-        <span class="n">kwargs</span><span class="o">.</span><span class="n">setdefault</span><span class="p">(</span><span class="s2">&quot;teacher_architecture&quot;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
-        <span class="n">kwargs</span><span class="o">.</span><span class="n">setdefault</span><span class="p">(</span><span class="s2">&quot;student_arch_params&quot;</span><span class="p">,</span> <span class="p">{})</span>
-        <span class="n">kwargs</span><span class="o">.</span><span class="n">setdefault</span><span class="p">(</span><span class="s2">&quot;teacher_arch_params&quot;</span><span class="p">,</span> <span class="p">{})</span>
-        <span class="n">kwargs</span><span class="o">.</span><span class="n">setdefault</span><span class="p">(</span><span class="s2">&quot;run_teacher_on_eval&quot;</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span>
+        <span class="n">val_dataloader</span> <span class="o">=</span> <span class="n">dataloaders</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">val_dataloader</span><span class="p">,</span>
+                                         <span class="n">dataset_params</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">val_dataset_params</span><span class="p">,</span>
+                                         <span class="n">dataloader_params</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">val_dataloader_params</span><span class="p">)</span>
 
 
-        <span class="bp">self</span><span class="o">.</span><span class="n">_validate_args</span><span class="p">(</span><span class="n">arch_params</span><span class="p">,</span> <span class="n">architecture</span><span class="p">,</span> <span class="n">checkpoint_params</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
+        <span class="n">student</span> <span class="o">=</span> <span class="n">models</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">cfg</span><span class="o">.</span><span class="n">student_architecture</span><span class="p">,</span> <span class="n">arch_params</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">student_arch_params</span><span class="p">,</span>
+                             <span class="n">strict_load</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">student_checkpoint_params</span><span class="o">.</span><span class="n">strict_load</span><span class="p">,</span>
+                             <span class="n">pretrained_weights</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">student_checkpoint_params</span><span class="o">.</span><span class="n">pretrained_weights</span><span class="p">,</span>
+                             <span class="n">checkpoint_path</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">student_checkpoint_params</span><span class="o">.</span><span class="n">checkpoint_path</span><span class="p">,</span>
+                             <span class="n">load_backbone</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">student_checkpoint_params</span><span class="o">.</span><span class="n">load_backbone</span><span class="p">)</span>
 
 
-        <span class="bp">self</span><span class="o">.</span><span class="n">student_architecture</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;student_architecture&quot;</span><span class="p">)</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">teacher_architecture</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;teacher_architecture&quot;</span><span class="p">)</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">student_arch_params</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;student_arch_params&quot;</span><span class="p">)</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">teacher_arch_params</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;teacher_arch_params&quot;</span><span class="p">)</span>
+        <span class="n">teacher</span> <span class="o">=</span> <span class="n">models</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">cfg</span><span class="o">.</span><span class="n">teacher_architecture</span><span class="p">,</span> <span class="n">arch_params</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">teacher_arch_params</span><span class="p">,</span>
+                             <span class="n">strict_load</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">teacher_checkpoint_params</span><span class="o">.</span><span class="n">strict_load</span><span class="p">,</span>
+                             <span class="n">pretrained_weights</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">teacher_checkpoint_params</span><span class="o">.</span><span class="n">pretrained_weights</span><span class="p">,</span>
+                             <span class="n">checkpoint_path</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">teacher_checkpoint_params</span><span class="o">.</span><span class="n">checkpoint_path</span><span class="p">,</span>
+                             <span class="n">load_backbone</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">teacher_checkpoint_params</span><span class="o">.</span><span class="n">load_backbone</span><span class="p">)</span>
 
 
-        <span class="nb">super</span><span class="p">(</span><span class="n">KDModel</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">build_model</span><span class="p">(</span><span class="n">architecture</span><span class="o">=</span><span class="n">architecture</span><span class="p">,</span> <span class="n">arch_params</span><span class="o">=</span><span class="n">arch_params</span><span class="p">,</span>
-                                         <span class="n">checkpoint_params</span><span class="o">=</span><span class="n">checkpoint_params</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>
+        <span class="c1"># TRAIN</span>
+        <span class="n">trainer</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="n">training_params</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">training_hyperparams</span><span class="p">,</span> <span class="n">student</span><span class="o">=</span><span class="n">student</span><span class="p">,</span> <span class="n">teacher</span><span class="o">=</span><span class="n">teacher</span><span class=
+                      <span class="n">kd_architecture</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">architecture</span><span class="p">,</span> <span class="n">kd_arch_params</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">arch_params</span><span class="p">,</span>
+                      <span class="n">run_teacher_on_eval</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">run_teacher_on_eval</span><span class="p">,</span>
+                      <span class="n">train_loader</span><span class="o">=</span><span class="n">train_dataloader</span><span class="p">,</span> <span class="n">valid_loader</span><span class="o">=</span><span class="n">val_dataloader</span><span class="p">)</span></div>
 
 
     <span class="k">def</span> <span class="nf">_validate_args</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">arch_params</span><span class="p">,</span> <span class="n">architecture</span><span class="p">,</span> <span class="n">checkpoint_params</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
     <span class="k">def</span> <span class="nf">_validate_args</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">arch_params</span><span class="p">,</span> <span class="n">architecture</span><span class="p">,</span> <span class="n">checkpoint_params</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
         <span class="n">student_architecture</span> <span class="o">=</span> <span class="n">get_param</span><span class="p">(</span><span class="n">kwargs</span><span class="p">,</span> <span class="s2">&quot;student_architecture&quot;</span><span class="p">)</span>
         <span class="n">student_architecture</span> <span class="o">=</span> <span class="n">get_param</span><span class="p">(</span><span class="n">kwargs</span><span class="p">,</span> <span class="s2">&quot;student_architecture&quot;</span><span class="p">)</span>
@@ -215,7 +201,8 @@
         <span class="n">load_kd_model_checkpoint</span> <span class="o">=</span> <span class="n">get_param</span><span class="p">(</span><span class="n">checkpoint_params</span><span class="p">,</span> <span class="s2">&quot;load_checkpoint&quot;</span><span class="p">)</span>
         <span class="n">load_kd_model_checkpoint</span> <span class="o">=</span> <span class="n">get_param</span><span class="p">(</span><span class="n">checkpoint_params</span><span class="p">,</span> <span class="s2">&quot;load_checkpoint&quot;</span><span class="p">)</span>
 
 
         <span class="c1"># CHECK THAT TEACHER NETWORK HOLDS KNOWLEDGE FOR THE STUDENT TO LEARN FROM OR THAT WE ARE LOADING AN ENTIRE KD</span>
         <span class="c1"># CHECK THAT TEACHER NETWORK HOLDS KNOWLEDGE FOR THE STUDENT TO LEARN FROM OR THAT WE ARE LOADING AN ENTIRE KD</span>
-        <span class="k">if</span> <span class="ow">not</span> <span class="p">(</span><span class="n">teacher_pretrained_weights</span> <span class="ow">or</span> <span class="n">teacher_checkpoint_path</span> <span class="ow">or</span> <span class="n">load_kd_model_checkpoint</span> <span class="ow">or</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">teacher_architecture</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span cla
+        <span class="k">if</span> <span class="ow">not</span> <span class="p">(</span><span class="n">teacher_pretrained_weights</span> <span class="ow">or</span> <span class="n">teacher_checkpoint_path</span> <span class="ow">or</span> <span class="n">load_kd_model_checkpoint</span> <span class="ow">or</span> <span class="nb">isinstance</span><span class="p">(</span>
+                <span class="n">teacher_architecture</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">)):</span>
             <span class="k">raise</span> <span class="n">TeacherKnowledgeException</span><span class="p">()</span>
             <span class="k">raise</span> <span class="n">TeacherKnowledgeException</span><span class="p">()</span>
 
 
     <span class="k">def</span> <span class="nf">_validate_num_classes</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">student_arch_params</span><span class="p">,</span> <span class="n">teacher_arch_params</span><span class="p">):</span>
     <span class="k">def</span> <span class="nf">_validate_num_classes</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">student_arch_params</span><span class="p">,</span> <span class="n">teacher_arch_params</span><span class="p">):</span>
@@ -277,6 +264,9 @@
 
 
         <span class="n">run_teacher_on_eval</span> <span class="o">=</span> <span class="n">get_param</span><span class="p">(</span><span class="n">kwargs</span><span class="p">,</span> <span class="s2">&quot;run_teacher_on_eval&quot;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
         <span class="n">run_teacher_on_eval</span> <span class="o">=</span> <span class="n">get_param</span><span class="p">(</span><span class="n">kwargs</span><span class="p">,</span> <span class="s2">&quot;run_teacher_on_eval&quot;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
 
 
+        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_instantiate_kd_net</span><span class="p">(</span><span class="n">arch_params</span><span class="p">,</span> <span class="n">architecture</span><span class="p">,</span> <span class="n">run_teacher_on_eval</span><span class="p">,</span> <span class="n">student</span><span class="p">,</span> <span class="n">teacher</span><span class="p">)</span>
+
+    <span class="k">def</span> <span class="nf">_instantiate_kd_net</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">arch_params</span><span class="p">,</span> <span class="n">architecture</span><span class="p">,</span> <span class="n">run_teacher_on_eval</span><span class="p">,</span> <span class="n">student</span><span class="p">,</span> <span class="n">teacher</span><span class="p">):</span>
         <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">architecture</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
         <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">architecture</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
             <span class="n">architecture_cls</span> <span class="o">=</span> <span class="n">KD_ARCHITECTURES</span><span class="p">[</span><span class="n">architecture</span><span class="p">]</span>
             <span class="n">architecture_cls</span> <span class="o">=</span> <span class="n">KD_ARCHITECTURES</span><span class="p">[</span><span class="n">architecture</span><span class="p">]</span>
             <span class="n">net</span> <span class="o">=</span> <span class="n">architecture_cls</span><span class="p">(</span><span class="n">arch_params</span><span class="o">=</span><span class="n">arch_params</span><span class="p">,</span> <span class="n">student</span><span class="o">=</span><span class="n">student</span><span class="p">,</span> <span class="n">teacher</span><span class="o">=</span><span class="n">teacher</span><span class="p">,</span>
             <span class="n">net</span> <span class="o">=</span> <span class="n">architecture_cls</span><span class="p">(</span><span class="n">arch_params</span><span class="o">=</span><span class="n">arch_params</span><span class="p">,</span> <span class="n">student</span><span class="o">=</span><span class="n">student</span><span class="p">,</span> <span class="n">teacher</span><span class="o">=</span><span class="n">teacher</span><span class="p">,</span>
@@ -286,13 +276,12 @@
                                <span class="n">run_teacher_on_eval</span><span class="o">=</span><span class="n">run_teacher_on_eval</span><span class="p">)</span>
                                <span class="n">run_teacher_on_eval</span><span class="o">=</span><span class="n">run_teacher_on_eval</span><span class="p">)</span>
         <span class="k">else</span><span class="p">:</span>
         <span class="k">else</span><span class="p">:</span>
             <span class="n">net</span> <span class="o">=</span> <span class="n">architecture</span>
             <span class="n">net</span> <span class="o">=</span> <span class="n">architecture</span>
-
         <span class="k">return</span> <span class="n">net</span>
         <span class="k">return</span> <span class="n">net</span>
 
 
     <span class="k">def</span> <span class="nf">_load_checkpoint_to_model</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
     <span class="k">def</span> <span class="nf">_load_checkpoint_to_model</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">        Initializes teacher weights with teacher_checkpoint_path if needed, then handles checkpoint loading for</span>
 <span class="sd">        Initializes teacher weights with teacher_checkpoint_path if needed, then handles checkpoint loading for</span>
-<span class="sd">         the entire KD network following the same logic as in SgModel.</span>
+<span class="sd">         the entire KD network following the same logic as in Trainer.</span>
 <span class="sd">        &quot;&quot;&quot;</span>
 <span class="sd">        &quot;&quot;&quot;</span>
         <span class="n">teacher_checkpoint_path</span> <span class="o">=</span> <span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_params</span><span class="p">,</span> <span class="s2">&quot;teacher_checkpoint_path&quot;</span><span class="p">)</span>
         <span class="n">teacher_checkpoint_path</span> <span class="o">=</span> <span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_params</span><span class="p">,</span> <span class="s2">&quot;teacher_checkpoint_path&quot;</span><span class="p">)</span>
         <span class="n">teacher_net</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="o">.</span><span class="n">module</span><span class="o">.</span><span class="n">teacher</span>
         <span class="n">teacher_net</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="o">.</span><span class="n">module</span><span class="o">.</span><span class="n">teacher</span>
@@ -315,7 +304,7 @@
                                      <span class="n">load_weights_only</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
                                      <span class="n">load_weights_only</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
                                      <span class="n">load_ema_as_net</span><span class="o">=</span><span class="n">load_teachers_ema</span><span class="p">)</span>
                                      <span class="n">load_ema_as_net</span><span class="o">=</span><span class="n">load_teachers_ema</span><span class="p">)</span>
 
 
-        <span class="nb">super</span><span class="p">(</span><span class="n">KDModel</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">_load_checkpoint_to_model</span><span class="p">()</span>
+        <span class="nb">super</span><span class="p">(</span><span class="n">KDTrainer</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">_load_checkpoint_to_model</span><span class="p">()</span>
 
 
     <span class="k">def</span> <span class="nf">_add_metrics_update_callback</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">phase</span><span class="p">):</span>
     <span class="k">def</span> <span class="nf">_add_metrics_update_callback</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">phase</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
@@ -337,7 +326,8 @@
                                    <span class="p">})</span>
                                    <span class="p">})</span>
         <span class="k">return</span> <span class="n">hyper_param_config</span>
         <span class="k">return</span> <span class="n">hyper_param_config</span>
 
 
-    <span class="k">def</span> <span class="nf">_instantiate_ema_model</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">decay</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.9999</span><span class="p">,</span> <span class="n">beta</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mi">15</span><span class="p">,</span> <span class="n">exp_acti
+    <span class="k">def</span> <span class="nf">_instantiate_ema_model</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">decay</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.9999</span><span class="p">,</span> <span class="n">beta</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mi">15</span><span class="p">,</span>
+                               <span class="n">exp_activation</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">KDModelEMA</span><span class="p">:</span>
         <span class="sd">&quot;&quot;&quot;Instantiate KD ema model for KDModule.</span>
         <span class="sd">&quot;&quot;&quot;Instantiate KD ema model for KDModule.</span>
 
 
 <span class="sd">        If the model is of class KDModule, the instance will be adapted to work on knowledge distillation.</span>
 <span class="sd">        If the model is of class KDModule, the instance will be adapted to work on knowledge distillation.</span>
@@ -360,7 +350,37 @@
             <span class="n">best_net</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">WrappedModel</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="o">.</span><span class="n">module</span><span class="o">.</span><span class="n">student</span><span class="p">)</span>
             <span class="n">best_net</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">WrappedModel</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="o">.</span><span class="n">module</span><span class="o">.</span><span class="n">student</span><span class="p">)</span>
 
 
         <span class="n">state</span><span class="p">[</span><span class="s2">&quot;net&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">best_net</span><span class="o">.</span><span class="n">state_dict</span><span class="p">()</span>
         <span class="n">state</span><span class="p">[</span><span class="s2">&quot;net&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">best_net</span><span class="o">.</span><span class="n">state_dict</span><span class="p">()</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">add_checkpoint</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">ckpt_best_name</span><span class="p">,</span> <span class="n">state_dict</span><span class="o">=</span><span class="n">state</span><span class="p">,</span> <span class="n">global_step</span><span class="o
+        <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">add_checkpoint</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">ckpt_best_name</span><span class="p">,</span> <span class="n">state_dict</span><span class="o">=</span><span class="n">state</span><span class="p">,</span> <span class="n">global_step</span><span class="o
+
+<div class="viewcode-block" id="KDTrainer.train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.KDTrainer.train">[docs]</a>    <span class="k">def</span> <span class="nf">train</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model</span><span class="p">:</span> <span class="n">KDModule</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">training_
+              <span class="n">teacher</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">kd_architecture</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">KDModule</span><span class="o">.</span><span class="vm">__class__</span><span class="p">,<
+              <span class="n">kd_arch_params</span><span class="p">:</span> <span class="nb">dict</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(),</span> <span class="n">run_teacher_on_eval</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">train_loader</span><span class="p">:</span> <span class="n">DataLoader</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+              <span class="n">valid_loader</span><span class="p">:</span> <span class="n">DataLoader</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
+        <span class="sd">&quot;&quot;&quot;</span>
+<span class="sd">        Trains the student network (wrapped in KDModule network).</span>
+
+<span class="sd">        :param model: KDModule, network to train. When none is given will initialize KDModule according to kd_architecture,</span>
+<span class="sd">            student and teacher (default=None)</span>
+<span class="sd">        :param training_params: dict, Same as in Trainer.train()</span>
+<span class="sd">        :param student: SgModule - the student trainer</span>
+<span class="sd">        :param teacher: torch.nn.Module- the teacher trainer</span>
+<span class="sd">        :param kd_architecture: KDModule architecture to use, currently only &#39;kd_module&#39; is supported (default=&#39;kd_module&#39;).</span>
+<span class="sd">        :param kd_arch_params: architecture params to pas to kd_architecture constructor.</span>
+<span class="sd">        :param run_teacher_on_eval: bool- whether to run self.teacher at eval mode regardless of self.train(mode)</span>
+<span class="sd">        :param train_loader: Dataloader for train set.</span>
+<span class="sd">        :param valid_loader: Dataloader for validation.</span>
+<span class="sd">        &quot;&quot;&quot;</span>
+        <span class="n">kd_net</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">net</span> <span class="ow">or</span> <span class="n">model</span>
+        <span class="k">if</span> <span class="n">kd_net</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="k">if</span> <span class="n">student</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="n">teacher</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+                <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Must pass student and teacher models or net (KDModule).&quot;</span><span class="p">)</span>
+            <span class="n">kd_net</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_instantiate_kd_net</span><span class="p">(</span><span class="n">arch_params</span><span class="o">=</span><span class="n">HpmStruct</span><span class="p">(</span><span class="o">**</span><span class="n">kd_arch_params</span><span class="p">),</span>
+                                              <span class="n">architecture</span><span class="o">=</span><span class="n">kd_architecture</span><span class="p">,</span>
+                                              <span class="n">run_teacher_on_eval</span><span class="o">=</span><span class="n">run_teacher_on_eval</span><span class="p">,</span>
+                                              <span class="n">student</span><span class="o">=</span><span class="n">student</span><span class="p">,</span>
+                                              <span class="n">teacher</span><span class="o">=</span><span class="n">teacher</span><span class="p">)</span>
+        <span class="nb">super</span><span class="p">(</span><span class="n">KDTrainer</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="n">model</span><span class="o">=</span><span class="n">kd_net</span><span class="p">,</span> <span class="n">training_params</span><span class="o">=</span><span class="n">training_params</span><span class="p">,</span>
+                                     <span class="n">train_loader</span><span class="o">=</span><span class="n">train_loader</span><span class="p">,</span> <span class="n">valid_loader</span><span class="o">=</span><span class="n">valid_loader</span><span class="p">)</span></div></div>
 </pre></div>
 </pre></div>
 
 
            </div>
            </div>
@@ -390,4 +410,4 @@
   </script> 
   </script> 
 
 
 </body>
 </body>
-</html>
+</html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.training.legacy.utils &mdash; SuperGradients 1.0 documentation</title>
  7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
  10. <!--[if lt IE 9]>
  11. <script src="../../../../_static/js/html5shiv.min.js"></script>
  12. <![endif]-->
  13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
  14. <script src="../../../../_static/jquery.js"></script>
  15. <script src="../../../../_static/underscore.js"></script>
  16. <script src="../../../../_static/doctools.js"></script>
  17. <script src="../../../../_static/js/theme.js"></script>
  18. <link rel="index" title="Index" href="../../../../genindex.html" />
  19. <link rel="search" title="Search" href="../../../../search.html" />
  20. </head>
  21. <body class="wy-body-for-nav">
  22. <div class="wy-grid-for-nav">
  23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  24. <div class="wy-side-scroll">
  25. <div class="wy-side-nav-search" >
  26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
  27. </a>
  28. <div role="search">
  29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
  30. <input type="text" name="q" placeholder="Search docs" />
  31. <input type="hidden" name="check_keywords" value="yes" />
  32. <input type="hidden" name="area" value="default" />
  33. </form>
  34. </div>
  35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
  37. <ul>
  38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
  39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
  40. </ul>
  41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
  42. <ul>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
  45. </ul>
  46. <p class="caption"><span class="caption-text">User Guide</span></p>
  47. <ul>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
  57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
  58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
  59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
  60. </ul>
  61. </div>
  62. </div>
  63. </nav>
  64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  66. <a href="../../../../index.html">SuperGradients</a>
  67. </nav>
  68. <div class="wy-nav-content">
  69. <div class="rst-content">
  70. <div role="navigation" aria-label="Page navigation">
  71. <ul class="wy-breadcrumbs">
  72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
  74. <li>super_gradients.training.legacy.utils</li>
  75. <li class="wy-breadcrumbs-aside">
  76. </li>
  77. </ul>
  78. <hr/>
  79. </div>
  80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  81. <div itemprop="articleBody">
  82. <h1>Source code for super_gradients.training.legacy.utils</h1><div class="highlight"><pre>
  83. <span></span><span class="kn">import</span> <span class="nn">os</span>
  84. <span class="kn">import</span> <span class="nn">torch</span>
  85. <span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
  86. <span class="kn">import</span> <span class="nn">torch.nn.init</span> <span class="k">as</span> <span class="nn">init</span>
  87. <div class="viewcode-block" id="prefetch_dataset"><a class="viewcode-back" href="../../../../super_gradients.training.legacy.html#super_gradients.training.legacy.utils.prefetch_dataset">[docs]</a><span class="k">def</span> <span class="nf">prefetch_dataset</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">num_workers</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">half</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
  88. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dataset</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
  89. <span class="n">tensors</span> <span class="o">=</span> <span class="n">dataset</span>
  90. <span class="k">else</span><span class="p">:</span>
  91. <span class="n">dataloader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span>
  92. <span class="n">dataset</span><span class="p">,</span>
  93. <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span>
  94. <span class="n">shuffle</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">drop_last</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
  95. <span class="n">num_workers</span><span class="o">=</span><span class="n">num_workers</span><span class="p">,</span> <span class="n">pin_memory</span><span class="o">=</span><span class="kc">False</span>
  96. <span class="p">)</span>
  97. <span class="n">tensors</span> <span class="o">=</span> <span class="p">[</span><span class="n">t</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">dataloader</span><span class="p">]</span>
  98. <span class="n">tensors</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">t</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="o">*</span><span class="n">tensors</span><span class="p">)]</span>
  99. <span class="k">if</span> <span class="n">device</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  100. <span class="n">tensors</span> <span class="o">=</span> <span class="p">[</span><span class="n">t</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">tensors</span><span class="p">]</span>
  101. <span class="k">if</span> <span class="n">half</span><span class="p">:</span>
  102. <span class="n">tensors</span> <span class="o">=</span> <span class="p">[</span><span class="n">t</span><span class="o">.</span><span class="n">half</span><span class="p">()</span> <span class="k">if</span> <span class="n">t</span><span class="o">.</span><span class="n">is_floating_point</span><span class="p">()</span> <span class="k">else</span> <span class="n">t</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">tensors</span><span class="p">]</span>
  103. <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">TensorDataset</span><span class="p">(</span><span class="o">*</span><span class="n">tensors</span><span class="p">)</span></div>
  104. <div class="viewcode-block" id="PrefetchDataLoader"><a class="viewcode-back" href="../../../../super_gradients.training.legacy.html#super_gradients.training.legacy.utils.PrefetchDataLoader">[docs]</a><span class="k">class</span> <span class="nc">PrefetchDataLoader</span><span class="p">:</span>
  105. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataloader</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">half</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
  106. <span class="bp">self</span><span class="o">.</span><span class="n">loader</span> <span class="o">=</span> <span class="n">dataloader</span>
  107. <span class="bp">self</span><span class="o">.</span><span class="n">iter</span> <span class="o">=</span> <span class="kc">None</span>
  108. <span class="bp">self</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">device</span>
  109. <span class="bp">self</span><span class="o">.</span><span class="n">dtype</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">float16</span> <span class="k">if</span> <span class="n">half</span> <span class="k">else</span> <span class="n">torch</span><span class="o">.</span><span class="n">float32</span>
  110. <span class="bp">self</span><span class="o">.</span><span class="n">stream</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Stream</span><span class="p">()</span>
  111. <span class="bp">self</span><span class="o">.</span><span class="n">next_data</span> <span class="o">=</span> <span class="kc">None</span>
  112. <span class="k">def</span> <span class="fm">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  113. <span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">loader</span><span class="p">)</span>
  114. <div class="viewcode-block" id="PrefetchDataLoader.async_prefech"><a class="viewcode-back" href="../../../../super_gradients.training.legacy.html#super_gradients.training.legacy.utils.PrefetchDataLoader.async_prefech">[docs]</a> <span class="k">def</span> <span class="nf">async_prefech</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  115. <span class="k">try</span><span class="p">:</span>
  116. <span class="bp">self</span><span class="o">.</span><span class="n">next_data</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">iter</span><span class="p">)</span>
  117. <span class="k">except</span> <span class="ne">StopIteration</span><span class="p">:</span>
  118. <span class="bp">self</span><span class="o">.</span><span class="n">next_data</span> <span class="o">=</span> <span class="kc">None</span>
  119. <span class="k">return</span>
  120. <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">stream</span><span class="p">):</span>
  121. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">next_data</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
  122. <span class="bp">self</span><span class="o">.</span><span class="n">next_data</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">next_data</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">non_blocking</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  123. <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">next_data</span><span class="p">,</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)):</span>
  124. <span class="bp">self</span><span class="o">.</span><span class="n">next_data</span> <span class="o">=</span> <span class="p">[</span>
  125. <span class="n">t</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">non_blocking</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="k">if</span> <span class="n">t</span><span class="o">.</span><span class="n">is_floating_point</span><span class="p">()</span> <span class="k">else</span> <span class="n">t</span><span class="o">.</span><span class="n">to</span><span class="p">(</span>
  126. <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">non_blocking</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">next_data</span>
  127. <span class="p">]</span></div>
  128. <span class="k">def</span> <span class="fm">__iter__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  129. <span class="bp">self</span><span class="o">.</span><span class="n">iter</span> <span class="o">=</span> <span class="nb">iter</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">loader</span><span class="p">)</span>
  130. <span class="bp">self</span><span class="o">.</span><span class="n">async_prefech</span><span class="p">()</span>
  131. <span class="k">while</span> <span class="bp">self</span><span class="o">.</span><span class="n">next_data</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  132. <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">current_stream</span><span class="p">()</span><span class="o">.</span><span class="n">wait_stream</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">stream</span><span class="p">)</span>
  133. <span class="n">data</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">next_data</span>
  134. <span class="bp">self</span><span class="o">.</span><span class="n">async_prefech</span><span class="p">()</span>
  135. <span class="k">yield</span> <span class="n">data</span></div>
  136. <div class="viewcode-block" id="init_params"><a class="viewcode-back" href="../../../../super_gradients.training.legacy.html#super_gradients.training.legacy.utils.init_params">[docs]</a><span class="k">def</span> <span class="nf">init_params</span><span class="p">(</span><span class="n">net</span><span class="p">):</span>
  137. <span class="sd">&quot;&quot;&quot;Init layer parameters.&quot;&quot;&quot;</span>
  138. <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="n">net</span><span class="o">.</span><span class="n">modules</span><span class="p">():</span>
  139. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">):</span>
  140. <span class="n">init</span><span class="o">.</span><span class="n">kaiming_normal</span><span class="p">(</span><span class="n">m</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s1">&#39;fan_out&#39;</span><span class="p">)</span>
  141. <span class="c1"># if m.bias:</span>
  142. <span class="c1"># init.constant(m.bias, -5)</span>
  143. <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm2d</span><span class="p">):</span>
  144. <span class="n">init</span><span class="o">.</span><span class="n">constant</span><span class="p">(</span><span class="n">m</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
  145. <span class="n">init</span><span class="o">.</span><span class="n">constant</span><span class="p">(</span><span class="n">m</span><span class="o">.</span><span class="n">bias</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
  146. <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">):</span>
  147. <span class="n">init</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">m</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="n">std</span><span class="o">=</span><span class="mf">1e-3</span><span class="p">)</span>
  148. <span class="k">if</span> <span class="n">m</span><span class="o">.</span><span class="n">bias</span><span class="p">:</span>
  149. <span class="n">init</span><span class="o">.</span><span class="n">constant</span><span class="p">(</span><span class="n">m</span><span class="o">.</span><span class="n">bias</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span></div>
  150. <div class="viewcode-block" id="format_time"><a class="viewcode-back" href="../../../../super_gradients.training.legacy.html#super_gradients.training.legacy.utils.format_time">[docs]</a><span class="k">def</span> <span class="nf">format_time</span><span class="p">(</span><span class="n">seconds</span><span class="p">):</span>
  151. <span class="n">days</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">seconds</span> <span class="o">/</span> <span class="mi">3600</span> <span class="o">/</span> <span class="mi">24</span><span class="p">)</span>
  152. <span class="n">seconds</span> <span class="o">=</span> <span class="n">seconds</span> <span class="o">-</span> <span class="n">days</span> <span class="o">*</span> <span class="mi">3600</span> <span class="o">*</span> <span class="mi">24</span>
  153. <span class="n">hours</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">seconds</span> <span class="o">/</span> <span class="mi">3600</span><span class="p">)</span>
  154. <span class="n">seconds</span> <span class="o">=</span> <span class="n">seconds</span> <span class="o">-</span> <span class="n">hours</span> <span class="o">*</span> <span class="mi">3600</span>
  155. <span class="n">minutes</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">seconds</span> <span class="o">/</span> <span class="mi">60</span><span class="p">)</span>
  156. <span class="n">seconds</span> <span class="o">=</span> <span class="n">seconds</span> <span class="o">-</span> <span class="n">minutes</span> <span class="o">*</span> <span class="mi">60</span>
  157. <span class="n">secondsf</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">seconds</span><span class="p">)</span>
  158. <span class="n">seconds</span> <span class="o">=</span> <span class="n">seconds</span> <span class="o">-</span> <span class="n">secondsf</span>
  159. <span class="n">millis</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">seconds</span> <span class="o">*</span> <span class="mi">1000</span><span class="p">)</span>
  160. <span class="n">f</span> <span class="o">=</span> <span class="s1">&#39;&#39;</span>
  161. <span class="n">i</span> <span class="o">=</span> <span class="mi">1</span>
  162. <span class="k">if</span> <span class="n">days</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
  163. <span class="n">f</span> <span class="o">+=</span> <span class="nb">str</span><span class="p">(</span><span class="n">days</span><span class="p">)</span> <span class="o">+</span> <span class="s1">&#39;D&#39;</span>
  164. <span class="n">i</span> <span class="o">+=</span> <span class="mi">1</span>
  165. <span class="k">if</span> <span class="n">hours</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">i</span> <span class="o">&lt;=</span> <span class="mi">2</span><span class="p">:</span>
  166. <span class="n">f</span> <span class="o">+=</span> <span class="nb">str</span><span class="p">(</span><span class="n">hours</span><span class="p">)</span> <span class="o">+</span> <span class="s1">&#39;h&#39;</span>
  167. <span class="n">i</span> <span class="o">+=</span> <span class="mi">1</span>
  168. <span class="k">if</span> <span class="n">minutes</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">i</span> <span class="o">&lt;=</span> <span class="mi">2</span><span class="p">:</span>
  169. <span class="n">f</span> <span class="o">+=</span> <span class="nb">str</span><span class="p">(</span><span class="n">minutes</span><span class="p">)</span> <span class="o">+</span> <span class="s1">&#39;m&#39;</span>
  170. <span class="n">i</span> <span class="o">+=</span> <span class="mi">1</span>
  171. <span class="k">if</span> <span class="n">secondsf</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">i</span> <span class="o">&lt;=</span> <span class="mi">2</span><span class="p">:</span>
  172. <span class="n">f</span> <span class="o">+=</span> <span class="nb">str</span><span class="p">(</span><span class="n">secondsf</span><span class="p">)</span> <span class="o">+</span> <span class="s1">&#39;s&#39;</span>
  173. <span class="n">i</span> <span class="o">+=</span> <span class="mi">1</span>
  174. <span class="k">if</span> <span class="n">millis</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">i</span> <span class="o">&lt;=</span> <span class="mi">2</span><span class="p">:</span>
  175. <span class="n">f</span> <span class="o">+=</span> <span class="nb">str</span><span class="p">(</span><span class="n">millis</span><span class="p">)</span> <span class="o">+</span> <span class="s1">&#39;ms&#39;</span>
  176. <span class="n">i</span> <span class="o">+=</span> <span class="mi">1</span>
  177. <span class="k">if</span> <span class="n">f</span> <span class="o">==</span> <span class="s1">&#39;&#39;</span><span class="p">:</span>
  178. <span class="n">f</span> <span class="o">=</span> <span class="s1">&#39;0ms&#39;</span>
  179. <span class="k">return</span> <span class="n">f</span></div>
  180. <div class="viewcode-block" id="is_better"><a class="viewcode-back" href="../../../../super_gradients.training.legacy.html#super_gradients.training.legacy.utils.is_better">[docs]</a><span class="k">def</span> <span class="nf">is_better</span><span class="p">(</span><span class="n">new_metric</span><span class="p">,</span> <span class="n">current_best_metric</span><span class="p">,</span> <span class="n">metric_to_watch</span><span class="o">=</span><span class="s1">&#39;acc&#39;</span><span class="p">):</span>
  181. <span class="sd">&quot;&quot;&quot;</span>
  182. <span class="sd"> Determines which of the two metrics is better, the higher if watching acc or lower when watching loss</span>
  183. <span class="sd"> :param new_metric: the new metric</span>
  184. <span class="sd"> :param current_best_metric: the compared to metric</span>
  185. <span class="sd"> :param metric_to_watch: acc or loss</span>
  186. <span class="sd"> :return: bool, True if new metric is better than current</span>
  187. <span class="sd"> &quot;&quot;&quot;</span>
  188. <span class="k">return</span> <span class="n">metric_to_watch</span> <span class="o">==</span> <span class="s1">&#39;acc&#39;</span> <span class="ow">and</span> <span class="n">new_metric</span> <span class="o">&gt;</span> <span class="n">current_best_metric</span> <span class="ow">or</span> <span class="p">(</span><span class="n">metric_to_watch</span> <span class="o">==</span> <span class="s1">&#39;loss&#39;</span> <span class="ow">and</span> <span class="n">current_best_metric</span> <span class="o">&gt;</span> <span class="n">new_metric</span><span class="p">)</span></div>
  189. <div class="viewcode-block" id="makedirs_if_not_exists"><a class="viewcode-back" href="../../../../super_gradients.training.legacy.html#super_gradients.training.legacy.utils.makedirs_if_not_exists">[docs]</a><span class="k">def</span> <span class="nf">makedirs_if_not_exists</span><span class="p">(</span><span class="n">dir_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
  190. <span class="sd">&quot;&quot;&quot;</span>
  191. <span class="sd"> make new directory in dir_path if it doesn&#39;t exists</span>
  192. <span class="sd"> :param dir_path - full path of directory</span>
  193. <span class="sd"> &quot;&quot;&quot;</span>
  194. <span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">dir_path</span><span class="p">):</span>
  195. <span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="n">dir_path</span><span class="p">)</span></div>
  196. </pre></div>
  197. </div>
  198. </div>
  199. <footer>
  200. <hr/>
  201. <div role="contentinfo">
  202. <p>&#169; Copyright 2021, SuperGradients team.</p>
  203. </div>
  204. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  205. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  206. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  207. </footer>
  208. </div>
  209. </div>
  210. </section>
  211. </div>
  212. <script>
  213. jQuery(function () {
  214. SphinxRtdTheme.Navigation.enable(true);
  215. });
  216. </script>
  217. </body>
  218. </html>
Discard
Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
@@ -3,10 +3,11 @@
 <head>
 <head>
   <meta charset="utf-8" />
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-  <title>super_gradients.training.losses.bce_dice_loss &mdash; SuperGradients 1.0 documentation</title>
+  <title>super_gradients.training.losses.bce_dice_loss &mdash; SuperGradients 3.0.3 documentation</title>
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
+      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
   <!--[if lt IE 9]>
   <!--[if lt IE 9]>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
   <![endif]-->
   <![endif]-->
@@ -14,7 +15,9 @@
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/underscore.js"></script>
         <script src="../../../../_static/underscore.js"></script>
+        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
         <script src="../../../../_static/doctools.js"></script>
         <script src="../../../../_static/doctools.js"></script>
+        <script src="../../../../_static/sphinx_highlight.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="search" title="Search" href="../../../../search.html" /> 
     <link rel="search" title="Search" href="../../../../search.html" /> 
@@ -35,30 +38,29 @@
   </form>
   </form>
 </div>
 </div>
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
-              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
+              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
 <ul>
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
 </ul>
 </ul>
-<p class="caption"><span class="caption-text">Technical Documentation</span></p>
+<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
 <ul>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
-</ul>
-<p class="caption"><span class="caption-text">User Guide</span></p>
-<ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
 </ul>
 </ul>
 
 
         </div>
         </div>
@@ -92,7 +94,7 @@
 <span class="kn">from</span> <span class="nn">super_gradients.training.losses.dice_loss</span> <span class="kn">import</span> <span class="n">BinaryDiceLoss</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.losses.dice_loss</span> <span class="kn">import</span> <span class="n">BinaryDiceLoss</span>
 
 
 
 
-<div class="viewcode-block" id="BCEDiceLoss"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.BCEDiceLoss">[docs]</a><span class="k">class</span> <span class="nc">BCEDiceLoss</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
+<div class="viewcode-block" id="BCEDiceLoss"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.BCEDiceLoss">[docs]</a><span class="k">class</span> <span class="nc">BCEDiceLoss</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    Binary Cross Entropy + Dice Loss</span>
 <span class="sd">    Binary Cross Entropy + Dice Loss</span>
 
 
@@ -108,7 +110,7 @@
         <span class="bp">self</span><span class="o">.</span><span class="n">bce</span> <span class="o">=</span> <span class="n">BCE</span><span class="p">()</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">bce</span> <span class="o">=</span> <span class="n">BCE</span><span class="p">()</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">dice</span> <span class="o">=</span> <span class="n">BinaryDiceLoss</span><span class="p">(</span><span class="n">apply_sigmoid</span><span class="o">=</span><span class="n">logits</span><span class="p">)</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">dice</span> <span class="o">=</span> <span class="n">BinaryDiceLoss</span><span class="p">(</span><span class="n">apply_sigmoid</span><span class="o">=</span><span class="n">logits</span><span class="p">)</span>
 
 
-<div class="viewcode-block" id="BCEDiceLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.BCEDiceLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="nb">input</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <spa
+<div class="viewcode-block" id="BCEDiceLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.BCEDiceLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="nb">input</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
 
 
 <span class="sd">        @param input: Network&#39;s raw output shaped (N,1,H,W)</span>
 <span class="sd">        @param input: Network&#39;s raw output shaped (N,1,H,W)</span>
@@ -145,4 +147,4 @@
   </script> 
   </script> 
 
 
 </body>
 </body>
-</html>
+</html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.training.losses.ddrnet_loss &mdash; SuperGradients 1.0 documentation</title>
  7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
  10. <!--[if lt IE 9]>
  11. <script src="../../../../_static/js/html5shiv.min.js"></script>
  12. <![endif]-->
  13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
  14. <script src="../../../../_static/jquery.js"></script>
  15. <script src="../../../../_static/underscore.js"></script>
  16. <script src="../../../../_static/doctools.js"></script>
  17. <script src="../../../../_static/js/theme.js"></script>
  18. <link rel="index" title="Index" href="../../../../genindex.html" />
  19. <link rel="search" title="Search" href="../../../../search.html" />
  20. </head>
  21. <body class="wy-body-for-nav">
  22. <div class="wy-grid-for-nav">
  23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  24. <div class="wy-side-scroll">
  25. <div class="wy-side-nav-search" >
  26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
  27. </a>
  28. <div role="search">
  29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
  30. <input type="text" name="q" placeholder="Search docs" />
  31. <input type="hidden" name="check_keywords" value="yes" />
  32. <input type="hidden" name="area" value="default" />
  33. </form>
  34. </div>
  35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
  37. <ul>
  38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
  39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
  40. </ul>
  41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
  42. <ul>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
  45. </ul>
  46. <p class="caption"><span class="caption-text">User Guide</span></p>
  47. <ul>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
  57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
  58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
  59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
  60. </ul>
  61. </div>
  62. </div>
  63. </nav>
  64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  66. <a href="../../../../index.html">SuperGradients</a>
  67. </nav>
  68. <div class="wy-nav-content">
  69. <div class="rst-content">
  70. <div role="navigation" aria-label="Page navigation">
  71. <ul class="wy-breadcrumbs">
  72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
  74. <li>super_gradients.training.losses.ddrnet_loss</li>
  75. <li class="wy-breadcrumbs-aside">
  76. </li>
  77. </ul>
  78. <hr/>
  79. </div>
  80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  81. <div itemprop="articleBody">
  82. <h1>Source code for super_gradients.training.losses.ddrnet_loss</h1><div class="highlight"><pre>
  83. <span></span><span class="kn">import</span> <span class="nn">torch</span>
  84. <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Union</span>
  85. <span class="kn">from</span> <span class="nn">super_gradients.training.losses.ohem_ce_loss</span> <span class="kn">import</span> <span class="n">OhemCELoss</span>
  86. <div class="viewcode-block" id="DDRNetLoss"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.ddrnet_loss.DDRNetLoss">[docs]</a><span class="k">class</span> <span class="nc">DDRNetLoss</span><span class="p">(</span><span class="n">OhemCELoss</span><span class="p">):</span>
  87. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
  88. <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.7</span><span class="p">,</span>
  89. <span class="n">ohem_percentage</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
  90. <span class="n">weights</span><span class="p">:</span> <span class="nb">list</span> <span class="o">=</span> <span class="p">[</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.4</span><span class="p">],</span>
  91. <span class="n">ignore_label</span><span class="o">=</span><span class="mi">255</span><span class="p">,</span>
  92. <span class="n">num_pixels_exclude_ignored</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
  93. <span class="sd">&quot;&quot;&quot;</span>
  94. <span class="sd"> This loss is an extension of the Ohem (Online Hard Example Mining Cross Entropy) Loss.</span>
  95. <span class="sd"> as define in paper:</span>
  96. <span class="sd"> Accurate Semantic Segmentation of Road Scenes ( https://arxiv.org/pdf/2101.06085.pdf )</span>
  97. <span class="sd"> :param threshold: threshold to th hard example mining algorithm</span>
  98. <span class="sd"> :param ohem_percentage: minimum percentage of total pixels for the hard example mining algorithm</span>
  99. <span class="sd"> (taking only the largest) losses</span>
  100. <span class="sd"> :param weights: weights per each input of the loss. This loss supports a multi output (like in DDRNet with</span>
  101. <span class="sd"> an auxiliary head). the losses of each head can be weighted.</span>
  102. <span class="sd"> :param ignore_label: targets label to be ignored</span>
  103. <span class="sd"> :param num_pixels_exclude_ignored: whether to exclude ignore pixels when calculating the mining percentage.</span>
  104. <span class="sd"> see OhemCELoss doc for more details.</span>
  105. <span class="sd"> &quot;&quot;&quot;</span>
  106. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">threshold</span><span class="o">=</span><span class="n">threshold</span><span class="p">,</span> <span class="n">mining_percent</span><span class="o">=</span><span class="n">ohem_percentage</span><span class="p">,</span> <span class="n">ignore_lb</span><span class="o">=</span><span class="n">ignore_label</span><span class="p">,</span>
  107. <span class="n">num_pixels_exclude_ignored</span><span class="o">=</span><span class="n">num_pixels_exclude_ignored</span><span class="p">)</span>
  108. <span class="bp">self</span><span class="o">.</span><span class="n">weights</span> <span class="o">=</span> <span class="n">weights</span>
  109. <div class="viewcode-block" id="DDRNetLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.ddrnet_loss.DDRNetLoss.forward">[docs]</a> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">predictions_list</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span>
  110. <span class="n">targets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
  111. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">predictions_list</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
  112. <span class="n">predictions_list</span> <span class="o">=</span> <span class="p">(</span><span class="n">predictions_list</span><span class="p">,)</span>
  113. <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">predictions_list</span><span class="p">)</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">weights</span><span class="p">),</span> <span class="s2">&quot;num of prediction must be the same as num of loss weights&quot;</span>
  114. <span class="n">losses</span> <span class="o">=</span> <span class="p">[]</span>
  115. <span class="n">unweighted_losses</span> <span class="o">=</span> <span class="p">[]</span>
  116. <span class="k">for</span> <span class="n">predictions</span><span class="p">,</span> <span class="n">weight</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">predictions_list</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">weights</span><span class="p">):</span>
  117. <span class="n">unweighted_loss</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">predictions</span><span class="p">,</span> <span class="n">targets</span><span class="p">)</span>
  118. <span class="n">unweighted_losses</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">unweighted_loss</span><span class="p">)</span>
  119. <span class="n">losses</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">unweighted_loss</span> <span class="o">*</span> <span class="n">weight</span><span class="p">)</span>
  120. <span class="n">total_loss</span> <span class="o">=</span> <span class="nb">sum</span><span class="p">(</span><span class="n">losses</span><span class="p">)</span>
  121. <span class="n">unweighted_losses</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">total_loss</span><span class="p">)</span>
  122. <span class="k">return</span> <span class="n">total_loss</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">unweighted_losses</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span></div></div>
  123. </pre></div>
  124. </div>
  125. </div>
  126. <footer>
  127. <hr/>
  128. <div role="contentinfo">
  129. <p>&#169; Copyright 2021, SuperGradients team.</p>
  130. </div>
  131. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  132. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  133. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  134. </footer>
  135. </div>
  136. </div>
  137. </section>
  138. </div>
  139. <script>
  140. jQuery(function () {
  141. SphinxRtdTheme.Navigation.enable(true);
  142. });
  143. </script>
  144. </body>
  145. </html>
Discard
Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
@@ -3,10 +3,11 @@
 <head>
 <head>
   <meta charset="utf-8" />
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-  <title>super_gradients.training.losses.dice_ce_edge_loss &mdash; SuperGradients 1.0 documentation</title>
+  <title>super_gradients.training.losses.dice_ce_edge_loss &mdash; SuperGradients 3.0.3 documentation</title>
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
+      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
   <!--[if lt IE 9]>
   <!--[if lt IE 9]>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
   <![endif]-->
   <![endif]-->
@@ -14,7 +15,9 @@
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/underscore.js"></script>
         <script src="../../../../_static/underscore.js"></script>
+        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
         <script src="../../../../_static/doctools.js"></script>
         <script src="../../../../_static/doctools.js"></script>
+        <script src="../../../../_static/sphinx_highlight.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="search" title="Search" href="../../../../search.html" /> 
     <link rel="search" title="Search" href="../../../../search.html" /> 
@@ -35,30 +38,29 @@
   </form>
   </form>
 </div>
 </div>
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
-              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
+              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
 <ul>
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
 </ul>
 </ul>
-<p class="caption"><span class="caption-text">Technical Documentation</span></p>
+<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
 <ul>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
-</ul>
-<p class="caption"><span class="caption-text">User Guide</span></p>
-<ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
 </ul>
 </ul>
 
 
         </div>
         </div>
@@ -95,7 +97,7 @@
 <span class="kn">from</span> <span class="nn">super_gradients.training.losses.mask_loss</span> <span class="kn">import</span> <span class="n">MaskAttentionLoss</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.losses.mask_loss</span> <span class="kn">import</span> <span class="n">MaskAttentionLoss</span>
 
 
 
 
-<div class="viewcode-block" id="DiceCEEdgeLoss"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.DiceCEEdgeLoss">[docs]</a><span class="k">class</span> <span class="nc">DiceCEEdgeLoss</span><span class="p">(</span><span class="n">_Loss</span><span class="p">):</span>
+<div class="viewcode-block" id="DiceCEEdgeLoss"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.DiceCEEdgeLoss">[docs]</a><span class="k">class</span> <span class="nc">DiceCEEdgeLoss</span><span class="p">(</span><span class="n">_Loss</span><span class="p">):</span>
     <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
     <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
                  <span class="n">num_classes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
                  <span class="n">num_classes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
                  <span class="n">num_aux_heads</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2</span><span class="p">,</span>
                  <span class="n">num_aux_heads</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2</span><span class="p">,</span>
@@ -151,7 +153,22 @@
         <span class="p">)</span>
         <span class="p">)</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">dice_loss</span> <span class="o">=</span> <span class="n">DiceLoss</span><span class="p">(</span><span class="n">apply_softmax</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">ignore_index</span><span class="o">=</span><span class="n">ignore_index</span><span class="p">)</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">dice_loss</span> <span class="o">=</span> <span class="n">DiceLoss</span><span class="p">(</span><span class="n">apply_softmax</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">ignore_index</span><span class="o">=</span><span class="n">ignore_index</span><span class="p">)</span>
 
 
-<div class="viewcode-block" id="DiceCEEdgeLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.DiceCEEdgeLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">preds</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><
+    <span class="nd">@property</span>
+    <span class="k">def</span> <span class="nf">component_names</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
+        <span class="sd">&quot;&quot;&quot;</span>
+<span class="sd">        Component names for logging during training.</span>
+<span class="sd">        These correspond to 2nd item in the tuple returned in self.forward(...).</span>
+<span class="sd">        See super_gradients.Trainer.train() docs for more info.</span>
+<span class="sd">        &quot;&quot;&quot;</span>
+        <span class="n">names</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;main_loss&quot;</span><span class="p">]</span>
+        <span class="c1"># Append aux losses names</span>
+        <span class="n">names</span> <span class="o">+=</span> <span class="p">[</span><span class="sa">f</span><span class="s2">&quot;aux_loss</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s2">&quot;</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_aux_heads</span><span class="p">)]</span
+        <span class="c1"># Append detail losses names</span>
+        <span class="n">names</span> <span class="o">+=</span> <span class="p">[</span><span class="sa">f</span><span class="s2">&quot;detail_loss</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s2">&quot;</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_detail_heads</span><span class="p">)]
+        <span class="n">names</span> <span class="o">+=</span> <span class="p">[</span><span class="s2">&quot;loss&quot;</span><span class="p">]</span>
+        <span class="k">return</span> <span class="n">names</span>
+
+<div class="viewcode-block" id="DiceCEEdgeLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.DiceCEEdgeLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">preds</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span cl
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">        :param preds: Model output predictions, must be in the followed format:</span>
 <span class="sd">        :param preds: Model output predictions, must be in the followed format:</span>
 <span class="sd">         [Main-feats, Aux-feats[0], ..., Aux-feats[num_auxs-1], Detail-feats[0], ..., Detail-feats[num_details-1]</span>
 <span class="sd">         [Main-feats, Aux-feats[0], ..., Aux-feats[num_auxs-1], Detail-feats[0], ..., Detail-feats[num_details-1]</span>
@@ -214,4 +231,4 @@
   </script> 
   </script> 
 
 
 </body>
 </body>
-</html>
+</html>
Discard
Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
@@ -3,10 +3,11 @@
 <head>
 <head>
   <meta charset="utf-8" />
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-  <title>super_gradients.training.losses.focal_loss &mdash; SuperGradients 1.0 documentation</title>
+  <title>super_gradients.training.losses.focal_loss &mdash; SuperGradients 3.0.3 documentation</title>
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
+      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
   <!--[if lt IE 9]>
   <!--[if lt IE 9]>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
   <![endif]-->
   <![endif]-->
@@ -14,7 +15,9 @@
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/underscore.js"></script>
         <script src="../../../../_static/underscore.js"></script>
+        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
         <script src="../../../../_static/doctools.js"></script>
         <script src="../../../../_static/doctools.js"></script>
+        <script src="../../../../_static/sphinx_highlight.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="search" title="Search" href="../../../../search.html" /> 
     <link rel="search" title="Search" href="../../../../search.html" /> 
@@ -35,30 +38,29 @@
   </form>
   </form>
 </div>
 </div>
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
-              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
+              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
 <ul>
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
 </ul>
 </ul>
-<p class="caption"><span class="caption-text">Technical Documentation</span></p>
+<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
 <ul>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
-</ul>
-<p class="caption"><span class="caption-text">User Guide</span></p>
-<ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
 </ul>
 </ul>
 
 
         </div>
         </div>
@@ -91,7 +93,7 @@
 <span class="kn">from</span> <span class="nn">torch.nn.modules.loss</span> <span class="kn">import</span> <span class="n">_Loss</span>
 <span class="kn">from</span> <span class="nn">torch.nn.modules.loss</span> <span class="kn">import</span> <span class="n">_Loss</span>
 
 
 
 
-<div class="viewcode-block" id="FocalLoss"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.FocalLoss">[docs]</a><span class="k">class</span> <span class="nc">FocalLoss</span><span class="p">(</span><span class="n">_Loss</span><span class="p">):</span>
+<div class="viewcode-block" id="FocalLoss"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.FocalLoss">[docs]</a><span class="k">class</span> <span class="nc">FocalLoss</span><span class="p">(</span><span class="n">_Loss</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)&quot;&quot;&quot;</span>
 
 
     <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">loss_fcn</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">BCEWithLogitsLoss</span><span class="p">,</span> <span class="n">gamma</span><span class="o">=</span><span class="mf">1.5</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.25</span><span
     <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">loss_fcn</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">BCEWithLogitsLoss</span><span class="p">,</span> <span class="n">gamma</span><span class="o">=</span><span class="mf">1.5</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.25</span><span
@@ -102,7 +104,7 @@
         <span class="bp">self</span><span class="o">.</span><span class="n">reduction</span> <span class="o">=</span> <span class="n">loss_fcn</span><span class="o">.</span><span class="n">reduction</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">reduction</span> <span class="o">=</span> <span class="n">loss_fcn</span><span class="o">.</span><span class="n">reduction</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">loss_fcn</span><span class="o">.</span><span class="n">reduction</span> <span class="o">=</span> <span class="s1">&#39;none&#39;</span>  <span class="c1"># required to apply FocalLoss to each element</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">loss_fcn</span><span class="o">.</span><span class="n">reduction</span> <span class="o">=</span> <span class="s1">&#39;none&#39;</span>  <span class="c1"># required to apply FocalLoss to each element</span>
 
 
-<div class="viewcode-block" id="FocalLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.FocalLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">pred</span><span class="p">,</span> <span class="n">true</span><span class="p">):</span>
+<div class="viewcode-block" id="FocalLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.FocalLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">pred</span><span class="p">,</span> <span class="n">true</span><span class="p">):</span>
         <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_fcn</span><span class="p">(</span><span class="n">pred</span><span class="p">,</span> <span class="n">true</span><span class="p">)</span>
         <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_fcn</span><span class="p">(</span><span class="n">pred</span><span class="p">,</span> <span class="n">true</span><span class="p">)</span>
 
 
         <span class="n">pred_prob</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">pred</span><span class="p">)</span>  <span class="c1"># prob from logits</span>
         <span class="n">pred_prob</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">pred</span><span class="p">)</span>  <span class="c1"># prob from logits</span>
@@ -146,4 +148,4 @@
   </script> 
   </script> 
 
 
 </body>
 </body>
-</html>
+</html>
Discard
Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
@@ -3,10 +3,11 @@
 <head>
 <head>
   <meta charset="utf-8" />
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-  <title>super_gradients.training.losses.kd_losses &mdash; SuperGradients 1.0 documentation</title>
+  <title>super_gradients.training.losses.kd_losses &mdash; SuperGradients 3.0.3 documentation</title>
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
+      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
   <!--[if lt IE 9]>
   <!--[if lt IE 9]>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
   <![endif]-->
   <![endif]-->
@@ -14,7 +15,9 @@
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/underscore.js"></script>
         <script src="../../../../_static/underscore.js"></script>
+        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
         <script src="../../../../_static/doctools.js"></script>
         <script src="../../../../_static/doctools.js"></script>
+        <script src="../../../../_static/sphinx_highlight.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="search" title="Search" href="../../../../search.html" /> 
     <link rel="search" title="Search" href="../../../../search.html" /> 
@@ -35,30 +38,29 @@
   </form>
   </form>
 </div>
 </div>
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
-              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
+              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
 <ul>
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
 </ul>
 </ul>
-<p class="caption"><span class="caption-text">Technical Documentation</span></p>
+<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
 <ul>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
-</ul>
-<p class="caption"><span class="caption-text">User Guide</span></p>
-<ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
 </ul>
 </ul>
 
 
         </div>
         </div>
@@ -100,7 +102,7 @@
                                                 <span class="n">torch</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">teacher_output</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">))</span>
                                                 <span class="n">torch</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">teacher_output</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">))</span>
 
 
 
 
-<div class="viewcode-block" id="KDLogitsLoss"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.KDLogitsLoss">[docs]</a><span class="k">class</span> <span class="nc">KDLogitsLoss</span><span class="p">(</span><span class="n">_Loss</span><span class="p">):</span>
+<div class="viewcode-block" id="KDLogitsLoss"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.KDLogitsLoss">[docs]</a><span class="k">class</span> <span class="nc">KDLogitsLoss</span><span class="p">(</span><span class="n">_Loss</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot; Knowledge distillation loss, wraps the task loss and distillation loss &quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot; Knowledge distillation loss, wraps the task loss and distillation loss &quot;&quot;&quot;</span>
     <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">task_loss_fn</span><span class="p">:</span> <span class="n">_Loss</span><span class="p">,</span> <span class="n">distillation_loss_fn</span><span class="p">:</span> <span class="n">_Loss</span> <span class="o">=</span> <span class="n">KDklDivLoss</span><span class="p">(),</span> <span class="n">distillation_loss_coeff</span><span class="p">:
     <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">task_loss_fn</span><span class="p">:</span> <span class="n">_Loss</span><span class="p">,</span> <span class="n">distillation_loss_fn</span><span class="p">:</span> <span class="n">_Loss</span> <span class="o">=</span> <span class="n">KDklDivLoss</span><span class="p">(),</span> <span class="n">distillation_loss_coeff</span><span class="p">:
         <span class="sd">&#39;&#39;&#39;</span>
         <span class="sd">&#39;&#39;&#39;</span>
@@ -114,7 +116,16 @@
         <span class="bp">self</span><span class="o">.</span><span class="n">distillation_loss_fn</span> <span class="o">=</span> <span class="n">distillation_loss_fn</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">distillation_loss_fn</span> <span class="o">=</span> <span class="n">distillation_loss_fn</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">distillation_loss_coeff</span> <span class="o">=</span> <span class="n">distillation_loss_coeff</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">distillation_loss_coeff</span> <span class="o">=</span> <span class="n">distillation_loss_coeff</span>
 
 
-<div class="viewcode-block" id="KDLogitsLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.KDLogitsLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">kd_module_output</span><span class="p">,</span> <span class="n">target</span><span class="p">):</span>
+    <span class="nd">@property</span>
+    <span class="k">def</span> <span class="nf">component_names</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
+        <span class="sd">&quot;&quot;&quot;</span>
+<span class="sd">        Component names for logging during training.</span>
+<span class="sd">        These correspond to 2nd item in the tuple returned in self.forward(...).</span>
+<span class="sd">        See super_gradients.Trainer.train() docs for more info.</span>
+<span class="sd">        &quot;&quot;&quot;</span>
+        <span class="k">return</span> <span class="p">[</span><span class="s2">&quot;Loss&quot;</span><span class="p">,</span> <span class="s2">&quot;Task Loss&quot;</span><span class="p">,</span> <span class="s2">&quot;Distillation Loss&quot;</span><span class="p">]</span>
+
+<div class="viewcode-block" id="KDLogitsLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.KDLogitsLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">kd_module_output</span><span class="p">,</span> <span class="n">target</span><span class="p">):</span>
         <span class="n">task_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">task_loss_fn</span><span class="p">(</span><span class="n">kd_module_output</span><span class="o">.</span><span class="n">student_output</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
         <span class="n">task_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">task_loss_fn</span><span class="p">(</span><span class="n">kd_module_output</span><span class="o">.</span><span class="n">student_output</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
         <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">task_loss</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">):</span>  <span class="c1"># SOME LOSS FUNCTIONS RETURNS LOSS AND LOG_ITEMS</span>
         <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">task_loss</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">):</span>  <span class="c1"># SOME LOSS FUNCTIONS RETURNS LOSS AND LOG_ITEMS</span>
             <span class="n">task_loss</span> <span class="o">=</span> <span class="n">task_loss</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
             <span class="n">task_loss</span> <span class="o">=</span> <span class="n">task_loss</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
@@ -151,4 +162,4 @@
   </script> 
   </script> 
 
 
 </body>
 </body>
-</html>
+</html>
Discard
Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
@@ -3,10 +3,11 @@
 <head>
 <head>
   <meta charset="utf-8" />
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-  <title>super_gradients.training.losses.label_smoothing_cross_entropy_loss &mdash; SuperGradients 1.0 documentation</title>
+  <title>super_gradients.training.losses.label_smoothing_cross_entropy_loss &mdash; SuperGradients 3.0.3 documentation</title>
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
+      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
   <!--[if lt IE 9]>
   <!--[if lt IE 9]>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
   <![endif]-->
   <![endif]-->
@@ -14,7 +15,9 @@
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/underscore.js"></script>
         <script src="../../../../_static/underscore.js"></script>
+        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
         <script src="../../../../_static/doctools.js"></script>
         <script src="../../../../_static/doctools.js"></script>
+        <script src="../../../../_static/sphinx_highlight.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="search" title="Search" href="../../../../search.html" /> 
     <link rel="search" title="Search" href="../../../../search.html" /> 
@@ -35,30 +38,29 @@
   </form>
   </form>
 </div>
 </div>
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
-              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
+              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
 <ul>
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
 </ul>
 </ul>
-<p class="caption"><span class="caption-text">Technical Documentation</span></p>
+<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
 <ul>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
-</ul>
-<p class="caption"><span class="caption-text">User Guide</span></p>
-<ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
 </ul>
 </ul>
 
 
         </div>
         </div>
@@ -91,7 +93,7 @@
 <span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="nn">F</span>
 <span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="nn">F</span>
 
 
 
 
-<div class="viewcode-block" id="onehot"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.onehot">[docs]</a><span class="k">def</span> <span class="nf">onehot</span><span class="p">(</span><span class="n">indexes</span><span class="p">,</span> <span class="n">N</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">ignore_index</span><span class="o">=</span><span class="kc">None</span><span
+<span class="k">def</span> <span class="nf">onehot</span><span class="p">(</span><span class="n">indexes</span><span class="p">,</span> <span class="n">N</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">ignore_index</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    Creates a one-hot representation of indexes with N possible entries</span>
 <span class="sd">    Creates a one-hot representation of indexes with N possible entries</span>
 <span class="sd">    if N is not specified, it will suit the maximum index appearing.</span>
 <span class="sd">    if N is not specified, it will suit the maximum index appearing.</span>
@@ -105,7 +107,7 @@
     <span class="n">output</span><span class="o">.</span><span class="n">scatter_</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">indexes</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="mi">1</span><span class="p">)</span>
     <span class="n">output</span><span class="o">.</span><span class="n">scatter_</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">indexes</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="mi">1</span><span class="p">)</span>
     <span class="k">if</span> <span class="n">ignore_index</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">ignore_index</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">:</span>
     <span class="k">if</span> <span class="n">ignore_index</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">ignore_index</span> <span class="o">&gt;=</span> <span class="mi">0</span><span class="p">:</span>
         <span class="n">output</span><span class="o">.</span><span class="n">masked_fill_</span><span class="p">(</span><span class="n">indexes</span><span class="o">.</span><span class="n">eq</span><span class="p">(</span><span class="n">ignore_index</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="mi">0</span><span class="p">)</span>
         <span class="n">output</span><span class="o">.</span><span class="n">masked_fill_</span><span class="p">(</span><span class="n">indexes</span><span class="o">.</span><span class="n">eq</span><span class="p">(</span><span class="n">ignore_index</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="mi">0</span><span class="p">)</span>
-    <span class="k">return</span> <span class="n">output</span></div>
+    <span class="k">return</span> <span class="n">output</span>
 
 
 
 
 <span class="k">def</span> <span class="nf">_is_long</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
 <span class="k">def</span> <span class="nf">_is_long</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
@@ -114,7 +116,7 @@
     <span class="k">return</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">LongTensor</span><span class="p">)</span> <span class="ow">or</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class
     <span class="k">return</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">LongTensor</span><span class="p">)</span> <span class="ow">or</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class
 
 
 
 
-<div class="viewcode-block" id="cross_entropy"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.cross_entropy">[docs]</a><span class="k">def</span> <span class="nf">cross_entropy</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">target</span><span class="p">,</span> <span class="n">weight</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="
+<span class="k">def</span> <span class="nf">cross_entropy</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">target</span><span class="p">,</span> <span class="n">weight</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">ignore_index</span><span class="o">=-</span><span class="mi">100</span><span class="p">,</span> <span class="n">reduction</span><span class="o">=</span><span class="s1">&#39;mean&#3
                   <span class="n">smooth_eps</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">smooth_dist</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">from_logits</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
                   <span class="n">smooth_eps</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">smooth_dist</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">from_logits</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;cross entropy loss, with support for target distributions and label smoothing https://arxiv.org/abs/1512.00567&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;cross entropy loss, with support for target distributions and label smoothing https://arxiv.org/abs/1512.00567&quot;&quot;&quot;</span>
     <span class="n">smooth_eps</span> <span class="o">=</span> <span class="n">smooth_eps</span> <span class="ow">or</span> <span class="mi">0</span>
     <span class="n">smooth_eps</span> <span class="o">=</span> <span class="n">smooth_eps</span> <span class="ow">or</span> <span class="mi">0</span>
@@ -166,10 +168,10 @@
         <span class="k">else</span><span class="p">:</span>
         <span class="k">else</span><span class="p">:</span>
             <span class="n">loss</span> <span class="o">=</span> <span class="n">loss</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <span class="o">/</span> <span class="nb">float</span><span class="p">(</span><span class="n">loss</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">-</span> <span class="n">masked_indices</span><span class="o">.</span><span class="n"
             <span class="n">loss</span> <span class="o">=</span> <span class="n">loss</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <span class="o">/</span> <span class="nb">float</span><span class="p">(</span><span class="n">loss</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">-</span> <span class="n">masked_indices</span><span class="o">.</span><span class="n"
 
 
-    <span class="k">return</span> <span class="n">loss</span></div>
+    <span class="k">return</span> <span class="n">loss</span>
 
 
 
 
-<div class="viewcode-block" id="LabelSmoothingCrossEntropyLoss"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.LabelSmoothingCrossEntropyLoss">[docs]</a><span class="k">class</span> <span class="nc">LabelSmoothingCrossEntropyLoss</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">CrossEntropyLoss</span><span class="p">):</span>
+<div class="viewcode-block" id="LabelSmoothingCrossEntropyLoss"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.LabelSmoothingCrossEntropyLoss">[docs]</a><span class="k">class</span> <span class="nc">LabelSmoothingCrossEntropyLoss</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">CrossEntropyLoss</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;CrossEntropyLoss - with ability to recieve distrbution as targets, and optional label smoothing&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;CrossEntropyLoss - with ability to recieve distrbution as targets, and optional label smoothing&quot;&quot;&quot;</span>
 
 
     <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">weight</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">ignore_index</span><span class="o">=-</span><span class="mi">100</span><span class="p">,</span> <span class="n">reduction</span><span class="o">=</span><span class="s1">&#39;mean&#39;</span><span class="p">,</span> <span class="n">smooth
     <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">weight</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">ignore_index</span><span class="o">=-</span><span class="mi">100</span><span class="p">,</span> <span class="n">reduction</span><span class="o">=</span><span class="s1">&#39;mean&#39;</span><span class="p">,</span> <span class="n">smooth
@@ -180,7 +182,7 @@
         <span class="bp">self</span><span class="o">.</span><span class="n">smooth_dist</span> <span class="o">=</span> <span class="n">smooth_dist</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">smooth_dist</span> <span class="o">=</span> <span class="n">smooth_dist</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">from_logits</span> <span class="o">=</span> <span class="n">from_logits</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">from_logits</span> <span class="o">=</span> <span class="n">from_logits</span>
 
 
-<div class="viewcode-block" id="LabelSmoothingCrossEntropyLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.LabelSmoothingCrossEntropyLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="nb">input</span><span class="p">,</span> <span class="n">target</span><span class="p">,</span> <span class="n">sm
+<div class="viewcode-block" id="LabelSmoothingCrossEntropyLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.LabelSmoothingCrossEntropyLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="nb">input</span><span class="p">,</span> <span class="n">target</span><span class="p">,</span> <span class="n">smooth_di
         <span class="k">if</span> <span class="n">smooth_dist</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
         <span class="k">if</span> <span class="n">smooth_dist</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
             <span class="n">smooth_dist</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">smooth_dist</span>
             <span class="n">smooth_dist</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">smooth_dist</span>
         <span class="n">loss</span> <span class="o">=</span> <span class="n">cross_entropy</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">target</span><span class="p">,</span> <span class="n">weight</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="n">ignore_index</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span>
         <span class="n">loss</span> <span class="o">=</span> <span class="n">cross_entropy</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">target</span><span class="p">,</span> <span class="n">weight</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="n">ignore_index</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span>
@@ -218,4 +220,4 @@
   </script> 
   </script> 
 
 
 </body>
 </body>
-</html>
+</html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.training.losses.ohem_ce_loss &mdash; SuperGradients 1.0 documentation</title>
  7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
  10. <!--[if lt IE 9]>
  11. <script src="../../../../_static/js/html5shiv.min.js"></script>
  12. <![endif]-->
  13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
  14. <script src="../../../../_static/jquery.js"></script>
  15. <script src="../../../../_static/underscore.js"></script>
  16. <script src="../../../../_static/doctools.js"></script>
  17. <script src="../../../../_static/js/theme.js"></script>
  18. <link rel="index" title="Index" href="../../../../genindex.html" />
  19. <link rel="search" title="Search" href="../../../../search.html" />
  20. </head>
  21. <body class="wy-body-for-nav">
  22. <div class="wy-grid-for-nav">
  23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  24. <div class="wy-side-scroll">
  25. <div class="wy-side-nav-search" >
  26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
  27. </a>
  28. <div role="search">
  29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
  30. <input type="text" name="q" placeholder="Search docs" />
  31. <input type="hidden" name="check_keywords" value="yes" />
  32. <input type="hidden" name="area" value="default" />
  33. </form>
  34. </div>
  35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
  37. <ul>
  38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
  39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
  40. </ul>
  41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
  42. <ul>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
  45. </ul>
  46. <p class="caption"><span class="caption-text">User Guide</span></p>
  47. <ul>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
  57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
  58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
  59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
  60. </ul>
  61. </div>
  62. </div>
  63. </nav>
  64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  66. <a href="../../../../index.html">SuperGradients</a>
  67. </nav>
  68. <div class="wy-nav-content">
  69. <div class="rst-content">
  70. <div role="navigation" aria-label="Page navigation">
  71. <ul class="wy-breadcrumbs">
  72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
  74. <li>super_gradients.training.losses.ohem_ce_loss</li>
  75. <li class="wy-breadcrumbs-aside">
  76. </li>
  77. </ul>
  78. <hr/>
  79. </div>
  80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  81. <div itemprop="articleBody">
  82. <h1>Source code for super_gradients.training.losses.ohem_ce_loss</h1><div class="highlight"><pre>
  83. <span></span><span class="kn">import</span> <span class="nn">torch</span>
  84. <span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
  85. <span class="kn">from</span> <span class="nn">torch.nn.modules.loss</span> <span class="kn">import</span> <span class="n">_Loss</span>
  86. <span class="kn">from</span> <span class="nn">super_gradients.training.exceptions.loss_exceptions</span> <span class="kn">import</span> <span class="n">IllegalRangeForLossAttributeException</span><span class="p">,</span> <span class="n">RequiredLossComponentReductionException</span>
  87. <div class="viewcode-block" id="OhemLoss"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.ohem_ce_loss.OhemLoss">[docs]</a><span class="k">class</span> <span class="nc">OhemLoss</span><span class="p">(</span><span class="n">_Loss</span><span class="p">):</span>
  88. <span class="sd">&quot;&quot;&quot;</span>
  89. <span class="sd"> OhemLoss - Online Hard Example Mining Cross Entropy Loss</span>
  90. <span class="sd"> &quot;&quot;&quot;</span>
  91. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
  92. <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
  93. <span class="n">mining_percent</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
  94. <span class="n">ignore_lb</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">-</span><span class="mi">100</span><span class="p">,</span>
  95. <span class="n">num_pixels_exclude_ignored</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
  96. <span class="n">criteria</span><span class="p">:</span> <span class="n">_Loss</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  97. <span class="sd">&quot;&quot;&quot;</span>
  98. <span class="sd"> :param threshold: Sample below probability threshold, is considered hard.</span>
  99. <span class="sd"> :param num_pixels_exclude_ignored: How to calculate total pixels from which extract mining percent of the</span>
  100. <span class="sd"> samples.</span>
  101. <span class="sd"> :param ignore_lb: label index to be ignored in loss calculation.</span>
  102. <span class="sd"> :param criteria: loss to mine the examples from.</span>
  103. <span class="sd"> i.e for num_pixels=100, ignore_pixels=30, mining_percent=0.1:</span>
  104. <span class="sd"> num_pixels_exclude_ignored=False =&gt; num_mining = 100 * 0.1 = 10</span>
  105. <span class="sd"> num_pixels_exclude_ignored=True =&gt; num_mining = (100 - 30) * 0.1 = 7</span>
  106. <span class="sd"> &quot;&quot;&quot;</span>
  107. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
  108. <span class="k">if</span> <span class="n">mining_percent</span> <span class="o">&lt;</span> <span class="mi">0</span> <span class="ow">or</span> <span class="n">mining_percent</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
  109. <span class="k">raise</span> <span class="n">IllegalRangeForLossAttributeException</span><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="s2">&quot;mining percent&quot;</span><span class="p">)</span>
  110. <span class="bp">self</span><span class="o">.</span><span class="n">thresh</span> <span class="o">=</span> <span class="o">-</span><span class="n">torch</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">threshold</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">))</span>
  111. <span class="bp">self</span><span class="o">.</span><span class="n">mining_percent</span> <span class="o">=</span> <span class="n">mining_percent</span>
  112. <span class="bp">self</span><span class="o">.</span><span class="n">ignore_lb</span> <span class="o">=</span> <span class="n">ignore_lb</span>
  113. <span class="bp">self</span><span class="o">.</span><span class="n">num_pixels_exclude_ignored</span> <span class="o">=</span> <span class="n">num_pixels_exclude_ignored</span>
  114. <span class="k">if</span> <span class="n">criteria</span><span class="o">.</span><span class="n">reduction</span> <span class="o">!=</span> <span class="s1">&#39;none&#39;</span><span class="p">:</span>
  115. <span class="k">raise</span> <span class="n">RequiredLossComponentReductionException</span><span class="p">(</span><span class="s2">&quot;criteria&quot;</span><span class="p">,</span> <span class="n">criteria</span><span class="o">.</span><span class="n">reduction</span><span class="p">,</span> <span class="s1">&#39;none&#39;</span><span class="p">)</span>
  116. <span class="bp">self</span><span class="o">.</span><span class="n">criteria</span> <span class="o">=</span> <span class="n">criteria</span>
  117. <div class="viewcode-block" id="OhemLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.ohem_ce_loss.OhemLoss.forward">[docs]</a> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logits</span><span class="p">,</span> <span class="n">labels</span><span class="p">):</span>
  118. <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">criteria</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">labels</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
  119. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_pixels_exclude_ignored</span><span class="p">:</span>
  120. <span class="c1"># remove ignore label elements</span>
  121. <span class="n">loss</span> <span class="o">=</span> <span class="n">loss</span><span class="p">[</span><span class="n">labels</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="o">!=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ignore_lb</span><span class="p">]</span>
  122. <span class="c1"># num pixels in a batch -&gt; num_pixels = batch_size * width * height - ignore_pixels</span>
  123. <span class="n">num_pixels</span> <span class="o">=</span> <span class="n">loss</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span>
  124. <span class="k">else</span><span class="p">:</span>
  125. <span class="n">num_pixels</span> <span class="o">=</span> <span class="n">labels</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span>
  126. <span class="c1"># if all pixels are ignore labels, return empty loss tensor</span>
  127. <span class="k">if</span> <span class="n">num_pixels</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
  128. <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mf">0.</span><span class="p">])</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">(</span><span class="kc">True</span><span class="p">)</span>
  129. <span class="n">num_mining</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mining_percent</span> <span class="o">*</span> <span class="n">num_pixels</span><span class="p">)</span>
  130. <span class="c1"># in case mining_percent=1, prevent out of bound exception</span>
  131. <span class="n">num_mining</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">num_mining</span><span class="p">,</span> <span class="n">num_pixels</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
  132. <span class="bp">self</span><span class="o">.</span><span class="n">thresh</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">thresh</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">logits</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
  133. <span class="n">loss</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sort</span><span class="p">(</span><span class="n">loss</span><span class="p">,</span> <span class="n">descending</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  134. <span class="k">if</span> <span class="n">loss</span><span class="p">[</span><span class="n">num_mining</span><span class="p">]</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">thresh</span><span class="p">:</span>
  135. <span class="n">loss</span> <span class="o">=</span> <span class="n">loss</span><span class="p">[</span><span class="n">loss</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">thresh</span><span class="p">]</span>
  136. <span class="k">else</span><span class="p">:</span>
  137. <span class="n">loss</span> <span class="o">=</span> <span class="n">loss</span><span class="p">[:</span><span class="n">num_mining</span><span class="p">]</span>
  138. <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">loss</span><span class="p">)</span></div></div>
  139. <div class="viewcode-block" id="OhemCELoss"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.ohem_ce_loss.OhemCELoss">[docs]</a><span class="k">class</span> <span class="nc">OhemCELoss</span><span class="p">(</span><span class="n">OhemLoss</span><span class="p">):</span>
  140. <span class="sd">&quot;&quot;&quot;</span>
  141. <span class="sd"> OhemLoss - Online Hard Example Mining Cross Entropy Loss</span>
  142. <span class="sd"> &quot;&quot;&quot;</span>
  143. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
  144. <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
  145. <span class="n">mining_percent</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
  146. <span class="n">ignore_lb</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">-</span><span class="mi">100</span><span class="p">,</span>
  147. <span class="n">num_pixels_exclude_ignored</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span>
  148. <span class="n">ignore_lb</span> <span class="o">=</span> <span class="o">-</span><span class="mi">100</span> <span class="k">if</span> <span class="n">ignore_lb</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="n">ignore_lb</span> <span class="o">&lt;</span> <span class="mi">0</span> <span class="k">else</span> <span class="n">ignore_lb</span>
  149. <span class="n">criteria</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">CrossEntropyLoss</span><span class="p">(</span><span class="n">ignore_index</span><span class="o">=</span><span class="n">ignore_lb</span><span class="p">,</span> <span class="n">reduction</span><span class="o">=</span><span class="s1">&#39;none&#39;</span><span class="p">)</span>
  150. <span class="nb">super</span><span class="p">(</span><span class="n">OhemCELoss</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">threshold</span><span class="o">=</span><span class="n">threshold</span><span class="p">,</span>
  151. <span class="n">mining_percent</span><span class="o">=</span><span class="n">mining_percent</span><span class="p">,</span>
  152. <span class="n">ignore_lb</span><span class="o">=</span><span class="n">ignore_lb</span><span class="p">,</span>
  153. <span class="n">num_pixels_exclude_ignored</span><span class="o">=</span><span class="n">num_pixels_exclude_ignored</span><span class="p">,</span>
  154. <span class="n">criteria</span><span class="o">=</span><span class="n">criteria</span><span class="p">)</span></div>
  155. <div class="viewcode-block" id="OhemBCELoss"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.ohem_ce_loss.OhemBCELoss">[docs]</a><span class="k">class</span> <span class="nc">OhemBCELoss</span><span class="p">(</span><span class="n">OhemLoss</span><span class="p">):</span>
  156. <span class="sd">&quot;&quot;&quot;</span>
  157. <span class="sd"> OhemBCELoss - Online Hard Example Mining Binary Cross Entropy Loss</span>
  158. <span class="sd"> &quot;&quot;&quot;</span>
  159. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
  160. <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
  161. <span class="n">mining_percent</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
  162. <span class="n">ignore_lb</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">-</span><span class="mi">100</span><span class="p">,</span>
  163. <span class="n">num_pixels_exclude_ignored</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> <span class="p">):</span>
  164. <span class="nb">super</span><span class="p">(</span><span class="n">OhemBCELoss</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">threshold</span><span class="o">=</span><span class="n">threshold</span><span class="p">,</span>
  165. <span class="n">mining_percent</span><span class="o">=</span><span class="n">mining_percent</span><span class="p">,</span>
  166. <span class="n">ignore_lb</span><span class="o">=</span><span class="n">ignore_lb</span><span class="p">,</span>
  167. <span class="n">num_pixels_exclude_ignored</span><span class="o">=</span><span class="n">num_pixels_exclude_ignored</span><span class="p">,</span>
  168. <span class="n">criteria</span><span class="o">=</span><span class="n">nn</span><span class="o">.</span><span class="n">BCEWithLogitsLoss</span><span class="p">(</span><span class="n">reduction</span><span class="o">=</span><span class="s1">&#39;none&#39;</span><span class="p">))</span>
  169. <div class="viewcode-block" id="OhemBCELoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.ohem_ce_loss.OhemBCELoss.forward">[docs]</a> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logits</span><span class="p">,</span> <span class="n">labels</span><span class="p">):</span>
  170. <span class="c1"># REMOVE SINGLE CLASS CHANNEL WHEN DEALING WITH BINARY DATA</span>
  171. <span class="k">if</span> <span class="n">logits</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
  172. <span class="n">logits</span> <span class="o">=</span> <span class="n">logits</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
  173. <span class="k">return</span> <span class="nb">super</span><span class="p">(</span><span class="n">OhemBCELoss</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">labels</span><span class="o">.</span><span class="n">float</span><span class="p">())</span></div></div>
  174. </pre></div>
  175. </div>
  176. </div>
  177. <footer>
  178. <hr/>
  179. <div role="contentinfo">
  180. <p>&#169; Copyright 2021, SuperGradients team.</p>
  181. </div>
  182. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  183. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  184. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  185. </footer>
  186. </div>
  187. </div>
  188. </section>
  189. </div>
  190. <script>
  191. jQuery(function () {
  192. SphinxRtdTheme.Navigation.enable(true);
  193. });
  194. </script>
  195. </body>
  196. </html>
Discard
@@ -3,10 +3,11 @@
 <head>
 <head>
   <meta charset="utf-8" />
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-  <title>super_gradients.training.losses.r_squared_loss &mdash; SuperGradients 1.0 documentation</title>
+  <title>super_gradients.training.losses.r_squared_loss &mdash; SuperGradients 3.0.3 documentation</title>
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
+      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
   <!--[if lt IE 9]>
   <!--[if lt IE 9]>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
   <![endif]-->
   <![endif]-->
@@ -14,7 +15,9 @@
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/underscore.js"></script>
         <script src="../../../../_static/underscore.js"></script>
+        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
         <script src="../../../../_static/doctools.js"></script>
         <script src="../../../../_static/doctools.js"></script>
+        <script src="../../../../_static/sphinx_highlight.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="search" title="Search" href="../../../../search.html" /> 
     <link rel="search" title="Search" href="../../../../search.html" /> 
@@ -35,30 +38,29 @@
   </form>
   </form>
 </div>
 </div>
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
-              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
+              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
 <ul>
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
 </ul>
 </ul>
-<p class="caption"><span class="caption-text">Technical Documentation</span></p>
+<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
 <ul>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
-</ul>
-<p class="caption"><span class="caption-text">User Guide</span></p>
-<ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
 </ul>
 </ul>
 
 
         </div>
         </div>
@@ -95,9 +97,9 @@
 <span class="kn">from</span> <span class="nn">super_gradients.training.utils</span> <span class="kn">import</span> <span class="n">convert_to_tensor</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.utils</span> <span class="kn">import</span> <span class="n">convert_to_tensor</span>
 
 
 
 
-<div class="viewcode-block" id="RSquaredLoss"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.RSquaredLoss">[docs]</a><span class="k">class</span> <span class="nc">RSquaredLoss</span><span class="p">(</span><span class="n">_Loss</span><span class="p">):</span>
+<div class="viewcode-block" id="RSquaredLoss"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.RSquaredLoss">[docs]</a><span class="k">class</span> <span class="nc">RSquaredLoss</span><span class="p">(</span><span class="n">_Loss</span><span class="p">):</span>
 
 
-<div class="viewcode-block" id="RSquaredLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.RSquaredLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">):</span>
+<div class="viewcode-block" id="RSquaredLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.RSquaredLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">):</span>
         <span class="c1"># FIXME - THIS NEEDS TO BE CHANGED SUCH THAT THIS CLASS INHERETS FROM _Loss (TAKE A LOOK AT YoLoV3DetectionLoss)</span>
         <span class="c1"># FIXME - THIS NEEDS TO BE CHANGED SUCH THAT THIS CLASS INHERETS FROM _Loss (TAKE A LOOK AT YoLoV3DetectionLoss)</span>
         <span class="sd">&quot;&quot;&quot;Computes the R-squared for the output and target values</span>
         <span class="sd">&quot;&quot;&quot;Computes the R-squared for the output and target values</span>
 <span class="sd">        :param output: Tensor / Numpy / List</span>
 <span class="sd">        :param output: Tensor / Numpy / List</span>
@@ -140,4 +142,4 @@
   </script> 
   </script> 
 
 
 </body>
 </body>
-</html>
+</html>
Discard
Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
@@ -3,10 +3,11 @@
 <head>
 <head>
   <meta charset="utf-8" />
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-  <title>super_gradients.training.losses.shelfnet_ohem_loss &mdash; SuperGradients 1.0 documentation</title>
+  <title>super_gradients.training.losses.shelfnet_ohem_loss &mdash; SuperGradients 3.0.3 documentation</title>
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
+      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
   <!--[if lt IE 9]>
   <!--[if lt IE 9]>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
   <![endif]-->
   <![endif]-->
@@ -14,7 +15,9 @@
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/underscore.js"></script>
         <script src="../../../../_static/underscore.js"></script>
+        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
         <script src="../../../../_static/doctools.js"></script>
         <script src="../../../../_static/doctools.js"></script>
+        <script src="../../../../_static/sphinx_highlight.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="search" title="Search" href="../../../../search.html" /> 
     <link rel="search" title="Search" href="../../../../search.html" /> 
@@ -35,30 +38,29 @@
   </form>
   </form>
 </div>
 </div>
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
-              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
+              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
 <ul>
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
 </ul>
 </ul>
-<p class="caption"><span class="caption-text">Technical Documentation</span></p>
+<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
 <ul>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
-</ul>
-<p class="caption"><span class="caption-text">User Guide</span></p>
-<ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
 </ul>
 </ul>
 
 
         </div>
         </div>
@@ -91,7 +93,7 @@
 <span class="kn">from</span> <span class="nn">super_gradients.training.losses.ohem_ce_loss</span> <span class="kn">import</span> <span class="n">OhemCELoss</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.losses.ohem_ce_loss</span> <span class="kn">import</span> <span class="n">OhemCELoss</span>
 
 
 
 
-<div class="viewcode-block" id="ShelfNetOHEMLoss"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.ShelfNetOHEMLoss">[docs]</a><span class="k">class</span> <span class="nc">ShelfNetOHEMLoss</span><span class="p">(</span><span class="n">OhemCELoss</span><span class="p">):</span>
+<div class="viewcode-block" id="ShelfNetOHEMLoss"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.ShelfNetOHEMLoss">[docs]</a><span class="k">class</span> <span class="nc">ShelfNetOHEMLoss</span><span class="p">(</span><span class="n">OhemCELoss</span><span class="p">):</span>
     <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.7</span><span class="p">,</span> <span class="n">mining_percent</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-4</span><span class="p">,</span> <span class="n">ignore_lb
     <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.7</span><span class="p">,</span> <span class="n">mining_percent</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-4</span><span class="p">,</span> <span class="n">ignore_lb
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">        This loss is an extension of the Ohem (Online Hard Example Mining Cross Entropy) Loss.</span>
 <span class="sd">        This loss is an extension of the Ohem (Online Hard Example Mining Cross Entropy) Loss.</span>
@@ -104,14 +106,23 @@
 <span class="sd">        &quot;&quot;&quot;</span>
 <span class="sd">        &quot;&quot;&quot;</span>
         <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">threshold</span><span class="o">=</span><span class="n">threshold</span><span class="p">,</span> <span class="n">mining_percent</span><span class="o">=</span><span class="n">mining_percent</span><span class="p">,</span> <span class="n">ignore_lb</span><span class="o">=</span><span class="n">ignore_lb</span><span class="p">)</span>
         <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">threshold</span><span class="o">=</span><span class="n">threshold</span><span class="p">,</span> <span class="n">mining_percent</span><span class="o">=</span><span class="n">mining_percent</span><span class="p">,</span> <span class="n">ignore_lb</span><span class="o">=</span><span class="n">ignore_lb</span><span class="p">)</span>
 
 
-<div class="viewcode-block" id="ShelfNetOHEMLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.ShelfNetOHEMLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">predictions_list</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span> <span class="n">targets</span><span c
+<div class="viewcode-block" id="ShelfNetOHEMLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.ShelfNetOHEMLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">predictions_list</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span> <span class="n">targets</span><span class="p
         <span class="n">losses</span> <span class="o">=</span> <span class="p">[]</span>
         <span class="n">losses</span> <span class="o">=</span> <span class="p">[]</span>
         <span class="k">for</span> <span class="n">predictions</span> <span class="ow">in</span> <span class="n">predictions_list</span><span class="p">:</span>
         <span class="k">for</span> <span class="n">predictions</span> <span class="ow">in</span> <span class="n">predictions_list</span><span class="p">:</span>
             <span class="n">losses</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">predictions</span><span class="p">,</span> <span class="n">targets</span><span class="p">))</span>
             <span class="n">losses</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">predictions</span><span class="p">,</span> <span class="n">targets</span><span class="p">))</span>
         <span class="n">total_loss</span> <span class="o">=</span> <span class="nb">sum</span><span class="p">(</span><span class="n">losses</span><span class="p">)</span>
         <span class="n">total_loss</span> <span class="o">=</span> <span class="nb">sum</span><span class="p">(</span><span class="n">losses</span><span class="p">)</span>
         <span class="n">losses</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">total_loss</span><span class="p">)</span>
         <span class="n">losses</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">total_loss</span><span class="p">)</span>
 
 
-        <span class="k">return</span> <span class="n">total_loss</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">losses</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span></div></div>
+        <span class="k">return</span> <span class="n">total_loss</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">losses</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span></div>
+
+    <span class="nd">@property</span>
+    <span class="k">def</span> <span class="nf">component_names</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
+        <span class="sd">&quot;&quot;&quot;</span>
+<span class="sd">        Component names for logging during training.</span>
+<span class="sd">        These correspond to 2nd item in the tuple returned in self.forward(...).</span>
+<span class="sd">        See super_gradients.Trainer.train() docs for more info.</span>
+<span class="sd">        &quot;&quot;&quot;</span>
+        <span class="k">return</span> <span class="p">[</span><span class="s2">&quot;Loss1/4&quot;</span><span class="p">,</span> <span class="s2">&quot;Loss1/8&quot;</span><span class="p">,</span> <span class="s2">&quot;Loss1/16&quot;</span><span class="p">,</span> <span class="s2">&quot;Loss&quot;</span><span class="p">]</span></div>
 </pre></div>
 </pre></div>
 
 
            </div>
            </div>
@@ -141,4 +152,4 @@
   </script> 
   </script> 
 
 
 </body>
 </body>
-</html>
+</html>
Discard
Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
@@ -3,10 +3,11 @@
 <head>
 <head>
   <meta charset="utf-8" />
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-  <title>super_gradients.training.losses.shelfnet_semantic_encoding_loss &mdash; SuperGradients 1.0 documentation</title>
+  <title>super_gradients.training.losses.shelfnet_semantic_encoding_loss &mdash; SuperGradients 3.0.3 documentation</title>
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
+      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
   <!--[if lt IE 9]>
   <!--[if lt IE 9]>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
   <![endif]-->
   <![endif]-->
@@ -14,7 +15,9 @@
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/underscore.js"></script>
         <script src="../../../../_static/underscore.js"></script>
+        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
         <script src="../../../../_static/doctools.js"></script>
         <script src="../../../../_static/doctools.js"></script>
+        <script src="../../../../_static/sphinx_highlight.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="search" title="Search" href="../../../../search.html" /> 
     <link rel="search" title="Search" href="../../../../search.html" /> 
@@ -35,30 +38,29 @@
   </form>
   </form>
 </div>
 </div>
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
-              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
+              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
 <ul>
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
 </ul>
 </ul>
-<p class="caption"><span class="caption-text">Technical Documentation</span></p>
+<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
 <ul>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
-</ul>
-<p class="caption"><span class="caption-text">User Guide</span></p>
-<ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
 </ul>
 </ul>
 
 
         </div>
         </div>
@@ -91,7 +93,7 @@
 <span class="kn">from</span> <span class="nn">torch.autograd</span> <span class="kn">import</span> <span class="n">Variable</span>
 <span class="kn">from</span> <span class="nn">torch.autograd</span> <span class="kn">import</span> <span class="n">Variable</span>
 
 
 
 
-<div class="viewcode-block" id="ShelfNetSemanticEncodingLoss"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.ShelfNetSemanticEncodingLoss">[docs]</a><span class="k">class</span> <span class="nc">ShelfNetSemanticEncodingLoss</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">CrossEntropyLoss</span><span class="p">):</span>
+<div class="viewcode-block" id="ShelfNetSemanticEncodingLoss"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.ShelfNetSemanticEncodingLoss">[docs]</a><span class="k">class</span> <span class="nc">ShelfNetSemanticEncodingLoss</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">CrossEntropyLoss</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;2D Cross Entropy Loss with Auxilary Loss&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;2D Cross Entropy Loss with Auxilary Loss&quot;&quot;&quot;</span>
 
 
     <span class="c1"># FIXME - THIS LOSS SHOULD BE CHANGED TO SUPPORT APEX</span>
     <span class="c1"># FIXME - THIS LOSS SHOULD BE CHANGED TO SUPPORT APEX</span>
@@ -104,7 +106,7 @@
         <span class="c1"># FIXME - TEST CODE LOTEM, CHANGED IN ORDER TO WORK WITH apex.amp</span>
         <span class="c1"># FIXME - TEST CODE LOTEM, CHANGED IN ORDER TO WORK WITH apex.amp</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">bcewithlogitsloss</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BCELoss</span><span class="p">(</span><span class="n">weight</span><span class="p">)</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">bcewithlogitsloss</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BCELoss</span><span class="p">(</span><span class="n">weight</span><span class="p">)</span>
 
 
-<div class="viewcode-block" id="ShelfNetSemanticEncodingLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.ShelfNetSemanticEncodingLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logits</span><span class="p">,</span> <span class="n">labels</span><span class="p">):</span>
+<div class="viewcode-block" id="ShelfNetSemanticEncodingLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.ShelfNetSemanticEncodingLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logits</span><span class="p">,</span> <span class="n">labels</span><span class="p">):</span>
         <span class="n">pred1</span><span class="p">,</span> <span class="n">se_pred</span><span class="p">,</span> <span class="n">pred2</span> <span class="o">=</span> <span class="n">logits</span>
         <span class="n">pred1</span><span class="p">,</span> <span class="n">se_pred</span><span class="p">,</span> <span class="n">pred2</span> <span class="o">=</span> <span class="n">logits</span>
 
 
         <span class="n">batch</span> <span class="o">=</span> <span class="n">labels</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
         <span class="n">batch</span> <span class="o">=</span> <span class="n">labels</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
@@ -122,7 +124,16 @@
         <span class="n">loss3</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">bcewithlogitsloss</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">se_pred</span><span class="p">),</span> <span class="n">se_target</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">cuda</span><span class="p">())</span
         <span class="n">loss3</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">bcewithlogitsloss</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">se_pred</span><span class="p">),</span> <span class="n">se_target</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">cuda</span><span class="p">())</span
         <span class="n">total_loss</span> <span class="o">=</span> <span class="n">loss1</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">aux_weight</span> <span class="o">*</span> <span class="n">loss2</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">se_weight</span> <span class="o">*</span> <span class="n">loss3</span>
         <span class="n">total_loss</span> <span class="o">=</span> <span class="n">loss1</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">aux_weight</span> <span class="o">*</span> <span class="n">loss2</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">se_weight</span> <span class="o">*</span> <span class="n">loss3</span>
         <span class="n">losses</span> <span class="o">=</span> <span class="p">[</span><span class="n">loss1</span><span class="p">,</span> <span class="n">loss2</span><span class="p">,</span> <span class="n">loss3</span><span class="p">,</span> <span class="n">total_loss</span><span class="p">]</span>
         <span class="n">losses</span> <span class="o">=</span> <span class="p">[</span><span class="n">loss1</span><span class="p">,</span> <span class="n">loss2</span><span class="p">,</span> <span class="n">loss3</span><span class="p">,</span> <span class="n">total_loss</span><span class="p">]</span>
-        <span class="k">return</span> <span class="n">total_loss</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">losses</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span></div></div>
+        <span class="k">return</span> <span class="n">total_loss</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">losses</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span></div>
+
+    <span class="nd">@property</span>
+    <span class="k">def</span> <span class="nf">component_names</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
+        <span class="sd">&quot;&quot;&quot;</span>
+<span class="sd">        Component names for logging during training.</span>
+<span class="sd">        These correspond to 2nd item in the tuple returned in self.forward(...).</span>
+<span class="sd">        See super_gradients.Trainer.train() docs for more info.</span>
+<span class="sd">        &quot;&quot;&quot;</span>
+        <span class="k">return</span> <span class="p">[</span><span class="s2">&quot;loss1&quot;</span><span class="p">,</span> <span class="s2">&quot;loss2&quot;</span><span class="p">,</span> <span class="s2">&quot;loss3&quot;</span><span class="p">,</span> <span class="s2">&quot;total_loss&quot;</span><span class="p">]</span></div>
 </pre></div>
 </pre></div>
 
 
            </div>
            </div>
@@ -152,4 +163,4 @@
   </script> 
   </script> 
 
 
 </body>
 </body>
-</html>
+</html>
Discard
Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
@@ -3,10 +3,11 @@
 <head>
 <head>
   <meta charset="utf-8" />
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-  <title>super_gradients.training.losses.ssd_loss &mdash; SuperGradients 1.0 documentation</title>
+  <title>super_gradients.training.losses.ssd_loss &mdash; SuperGradients 3.0.3 documentation</title>
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
+      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
   <!--[if lt IE 9]>
   <!--[if lt IE 9]>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
   <![endif]-->
   <![endif]-->
@@ -14,7 +15,9 @@
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/underscore.js"></script>
         <script src="../../../../_static/underscore.js"></script>
+        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
         <script src="../../../../_static/doctools.js"></script>
         <script src="../../../../_static/doctools.js"></script>
+        <script src="../../../../_static/sphinx_highlight.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="search" title="Search" href="../../../../search.html" /> 
     <link rel="search" title="Search" href="../../../../search.html" /> 
@@ -35,30 +38,29 @@
   </form>
   </form>
 </div>
 </div>
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
-              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
+              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
 <ul>
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
 </ul>
 </ul>
-<p class="caption"><span class="caption-text">Technical Documentation</span></p>
+<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
 <ul>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
-</ul>
-<p class="caption"><span class="caption-text">User Guide</span></p>
-<ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
 </ul>
 </ul>
 
 
         </div>
         </div>
@@ -96,7 +98,7 @@
 <span class="kn">from</span> <span class="nn">super_gradients.training.utils.ssd_utils</span> <span class="kn">import</span> <span class="n">DefaultBoxes</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.utils.ssd_utils</span> <span class="kn">import</span> <span class="n">DefaultBoxes</span>
 
 
 
 
-<div class="viewcode-block" id="HardMiningCrossEntropyLoss"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.HardMiningCrossEntropyLoss">[docs]</a><span class="k">class</span> <span class="nc">HardMiningCrossEntropyLoss</span><span class="p">(</span><span class="n">_Loss</span><span class="p">):</span>
+<span class="k">class</span> <span class="nc">HardMiningCrossEntropyLoss</span><span class="p">(</span><span class="n">_Loss</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    L_cls = [CE of all positives] + [CE of the hardest backgrounds]</span>
 <span class="sd">    L_cls = [CE of all positives] + [CE of the hardest backgrounds]</span>
 <span class="sd">    where the second term is built from [neg_pos_ratio * positive pairs] background cells with the highest CE</span>
 <span class="sd">    where the second term is built from [neg_pos_ratio * positive pairs] background cells with the highest CE</span>
@@ -113,7 +115,7 @@
         <span class="bp">self</span><span class="o">.</span><span class="n">neg_pos_ratio</span> <span class="o">=</span> <span class="n">neg_pos_ratio</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">neg_pos_ratio</span> <span class="o">=</span> <span class="n">neg_pos_ratio</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">ce</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">CrossEntropyLoss</span><span class="p">(</span><span class="n">reduce</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">ce</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">CrossEntropyLoss</span><span class="p">(</span><span class="n">reduce</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
 
 
-<div class="viewcode-block" id="HardMiningCrossEntropyLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.HardMiningCrossEntropyLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">pred_labels</span><span class="p">,</span> <span class="n">target_labels</span><span class="p">):</span>
+    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">pred_labels</span><span class="p">,</span> <span class="n">target_labels</span><span class="p">):</span>
         <span class="n">mask</span> <span class="o">=</span> <span class="n">target_labels</span> <span class="o">&gt;</span> <span class="mi">0</span>  <span class="c1"># not background</span>
         <span class="n">mask</span> <span class="o">=</span> <span class="n">target_labels</span> <span class="o">&gt;</span> <span class="mi">0</span>  <span class="c1"># not background</span>
         <span class="n">pos_num</span> <span class="o">=</span> <span class="n">mask</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
         <span class="n">pos_num</span> <span class="o">=</span> <span class="n">mask</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
 
 
@@ -135,10 +137,10 @@
         <span class="n">neg_mask</span> <span class="o">=</span> <span class="n">con_rank</span> <span class="o">&lt;</span> <span class="n">neg_num</span>
         <span class="n">neg_mask</span> <span class="o">=</span> <span class="n">con_rank</span> <span class="o">&lt;</span> <span class="n">neg_num</span>
 
 
         <span class="n">closs</span> <span class="o">=</span> <span class="p">(</span><span class="n">con</span> <span class="o">*</span> <span class="p">(</span><span class="n">mask</span><span class="o">.</span><span class="n">float</span><span class="p">()</span> <span class="o">+</span> <span class="n">neg_mask</span><span class="o">.</span><span class="n">float</span><span class="p">()))</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</sp
         <span class="n">closs</span> <span class="o">=</span> <span class="p">(</span><span class="n">con</span> <span class="o">*</span> <span class="p">(</span><span class="n">mask</span><span class="o">.</span><span class="n">float</span><span class="p">()</span> <span class="o">+</span> <span class="n">neg_mask</span><span class="o">.</span><span class="n">float</span><span class="p">()))</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</sp
-        <span class="k">return</span> <span class="n">closs</span></div></div>
+        <span class="k">return</span> <span class="n">closs</span>
 
 
 
 
-<div class="viewcode-block" id="SSDLoss"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.SSDLoss">[docs]</a><span class="k">class</span> <span class="nc">SSDLoss</span><span class="p">(</span><span class="n">_Loss</span><span class="p">):</span>
+<div class="viewcode-block" id="SSDLoss"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.SSDLoss">[docs]</a><span class="k">class</span> <span class="nc">SSDLoss</span><span class="p">(</span><span class="n">_Loss</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">        Implements the loss as the sum of the followings:</span>
 <span class="sd">        Implements the loss as the sum of the followings:</span>
 <span class="sd">        1. Confidence Loss: All labels, with hard negative mining</span>
 <span class="sd">        1. Confidence Loss: All labels, with hard negative mining</span>
@@ -167,6 +169,15 @@
         <span class="bp">self</span><span class="o">.</span><span class="n">con_loss</span> <span class="o">=</span> <span class="n">HardMiningCrossEntropyLoss</span><span class="p">(</span><span class="n">neg_pos_ratio</span><span class="p">)</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">con_loss</span> <span class="o">=</span> <span class="n">HardMiningCrossEntropyLoss</span><span class="p">(</span><span class="n">neg_pos_ratio</span><span class="p">)</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">iou_thresh</span> <span class="o">=</span> <span class="n">iou_thresh</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">iou_thresh</span> <span class="o">=</span> <span class="n">iou_thresh</span>
 
 
+    <span class="nd">@property</span>
+    <span class="k">def</span> <span class="nf">component_names</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
+        <span class="sd">&quot;&quot;&quot;</span>
+<span class="sd">        Component names for logging during training.</span>
+<span class="sd">        These correspond to 2nd item in the tuple returned in self.forward(...).</span>
+<span class="sd">        See super_gradients.Trainer.train() docs for more info.</span>
+<span class="sd">        &quot;&quot;&quot;</span>
+        <span class="k">return</span> <span class="p">[</span><span class="s2">&quot;smooth_l1&quot;</span><span class="p">,</span> <span class="s2">&quot;closs&quot;</span><span class="p">,</span> <span class="s2">&quot;Loss&quot;</span><span class="p">]</span>
+
     <span class="k">def</span> <span class="nf">_norm_relative_bbox</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">loc</span><span class="p">):</span>
     <span class="k">def</span> <span class="nf">_norm_relative_bbox</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">loc</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">        convert bbox locations into relative locations (relative to the dboxes)</span>
 <span class="sd">        convert bbox locations into relative locations (relative to the dboxes)</span>
@@ -176,7 +187,7 @@
         <span class="n">gwh</span> <span class="o">=</span> <span class="p">(</span><span class="n">loc</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:,</span> <span class="p">:]</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:,</span> <span class="p">:])</span><span class="o">.</span><span class="n">log</span><span class="p">()</span>
         <span class="n">gwh</span> <span class="o">=</span> <span class="p">(</span><span class="n">loc</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:,</span> <span class="p">:]</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:,</span> <span class="p">:])</span><span class="o">.</span><span class="n">log</span><span class="p">()</span>
         <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">gxy</span><span class="p">,</span> <span class="n">gwh</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
         <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">gxy</span><span class="p">,</span> <span class="n">gwh</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
 
 
-<div class="viewcode-block" id="SSDLoss.match_dboxes"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.SSDLoss.match_dboxes">[docs]</a>    <span class="k">def</span> <span class="nf">match_dboxes</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">targets</span><span class="p">):</span>
+<div class="viewcode-block" id="SSDLoss.match_dboxes"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.SSDLoss.match_dboxes">[docs]</a>    <span class="k">def</span> <span class="nf">match_dboxes</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">targets</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">        creates tensors with target boxes and labels for each dboxes, so with the same len as dboxes.</span>
 <span class="sd">        creates tensors with target boxes and labels for each dboxes, so with the same len as dboxes.</span>
 
 
@@ -220,7 +231,7 @@
 
 
         <span class="k">return</span> <span class="n">each_cell_target_locations</span><span class="p">,</span> <span class="n">each_cell_target_labels</span></div>
         <span class="k">return</span> <span class="n">each_cell_target_locations</span><span class="p">,</span> <span class="n">each_cell_target_labels</span></div>
 
 
-<div class="viewcode-block" id="SSDLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.SSDLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">predictions</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">targets</span><span class="p">):</span>
+<div class="viewcode-block" id="SSDLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.SSDLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">predictions</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">targets</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">        Compute the loss</span>
 <span class="sd">        Compute the loss</span>
 <span class="sd">            :param predictions - predictions tensor coming from the network,</span>
 <span class="sd">            :param predictions - predictions tensor coming from the network,</span>
@@ -289,4 +300,4 @@
   </script> 
   </script> 
 
 
 </body>
 </body>
-</html>
+</html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.training.losses.yolo_v3_loss &mdash; SuperGradients 1.0 documentation</title>
  7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
  10. <!--[if lt IE 9]>
  11. <script src="../../../../_static/js/html5shiv.min.js"></script>
  12. <![endif]-->
  13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
  14. <script src="../../../../_static/jquery.js"></script>
  15. <script src="../../../../_static/underscore.js"></script>
  16. <script src="../../../../_static/doctools.js"></script>
  17. <script src="../../../../_static/js/theme.js"></script>
  18. <link rel="index" title="Index" href="../../../../genindex.html" />
  19. <link rel="search" title="Search" href="../../../../search.html" />
  20. </head>
  21. <body class="wy-body-for-nav">
  22. <div class="wy-grid-for-nav">
  23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  24. <div class="wy-side-scroll">
  25. <div class="wy-side-nav-search" >
  26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
  27. </a>
  28. <div role="search">
  29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
  30. <input type="text" name="q" placeholder="Search docs" />
  31. <input type="hidden" name="check_keywords" value="yes" />
  32. <input type="hidden" name="area" value="default" />
  33. </form>
  34. </div>
  35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  36. <p class="caption"><span class="caption-text">SuperGradients</span></p>
  37. <ul>
  38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">SuperGradients</a></li>
  39. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common</a></li>
  40. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training</a></li>
  41. </ul>
  42. </div>
  43. </div>
  44. </nav>
  45. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  46. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  47. <a href="../../../../index.html">SuperGradients</a>
  48. </nav>
  49. <div class="wy-nav-content">
  50. <div class="rst-content">
  51. <div role="navigation" aria-label="Page navigation">
  52. <ul class="wy-breadcrumbs">
  53. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  54. <li><a href="../../../index.html">Module code</a> &raquo;</li>
  55. <li>super_gradients.training.losses.yolo_v3_loss</li>
  56. <li class="wy-breadcrumbs-aside">
  57. </li>
  58. </ul>
  59. <hr/>
  60. </div>
  61. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  62. <div itemprop="articleBody">
  63. <h1>Source code for super_gradients.training.losses.yolo_v3_loss</h1><div class="highlight"><pre>
  64. <span></span><span class="kn">import</span> <span class="nn">torch</span>
  65. <span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
  66. <span class="kn">from</span> <span class="nn">torch.nn.modules.loss</span> <span class="kn">import</span> <span class="n">_Loss</span>
  67. <span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">build_detection_targets</span><span class="p">,</span> <span class="n">calculate_bbox_iou_elementwise</span>
  68. <div class="viewcode-block" id="YoLoV3DetectionLoss"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.yolo_v3_loss.YoLoV3DetectionLoss">[docs]</a><span class="k">class</span> <span class="nc">YoLoV3DetectionLoss</span><span class="p">(</span><span class="n">_Loss</span><span class="p">):</span>
  69. <span class="sd">&quot;&quot;&quot;</span>
  70. <span class="sd"> YoLoV3DetectionLoss - Loss Class for Object Detection</span>
  71. <span class="sd"> &quot;&quot;&quot;</span>
  72. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">cls_pw</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.</span><span class="p">,</span> <span class="n">obj_pw</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.</span><span class="p">,</span> <span class="n">giou</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">3.54</span><span class="p">,</span> <span class="n">obj</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">64.3</span><span class="p">,</span>
  73. <span class="bp">cls</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">37.4</span><span class="p">):</span>
  74. <span class="nb">super</span><span class="p">(</span><span class="n">YoLoV3DetectionLoss</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
  75. <span class="bp">self</span><span class="o">.</span><span class="n">model</span> <span class="o">=</span> <span class="n">model</span>
  76. <span class="bp">self</span><span class="o">.</span><span class="n">cls_pw</span> <span class="o">=</span> <span class="n">cls_pw</span>
  77. <span class="bp">self</span><span class="o">.</span><span class="n">obj_pw</span> <span class="o">=</span> <span class="n">obj_pw</span>
  78. <span class="bp">self</span><span class="o">.</span><span class="n">giou</span> <span class="o">=</span> <span class="n">giou</span>
  79. <span class="bp">self</span><span class="o">.</span><span class="n">obj</span> <span class="o">=</span> <span class="n">obj</span>
  80. <span class="bp">self</span><span class="o">.</span><span class="n">cls</span> <span class="o">=</span> <span class="bp">cls</span>
  81. <span class="bp">self</span><span class="o">.</span><span class="n">classes_num</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">net</span><span class="o">.</span><span class="n">module</span><span class="o">.</span><span class="n">num_classes</span>
  82. <div class="viewcode-block" id="YoLoV3DetectionLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.yolo_v3_loss.YoLoV3DetectionLoss.forward">[docs]</a> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model_output</span><span class="p">,</span> <span class="n">targets</span><span class="p">):</span>
  83. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">model_output</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">model_output</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span>
  84. <span class="c1"># in test/eval mode the Yolo v3 model output a tuple where the second item is the raw predictions</span>
  85. <span class="n">_</span><span class="p">,</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model_output</span>
  86. <span class="k">else</span><span class="p">:</span>
  87. <span class="n">predictions</span> <span class="o">=</span> <span class="n">model_output</span>
  88. <span class="n">detection_targets</span> <span class="o">=</span> <span class="n">build_detection_targets</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">net</span><span class="o">.</span><span class="n">module</span><span class="p">,</span> <span class="n">targets</span><span class="p">)</span>
  89. <span class="n">float_tensor</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">FloatTensor</span> <span class="k">if</span> <span class="n">predictions</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">is_cuda</span> <span class="k">else</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span>
  90. <span class="n">class_loss</span><span class="p">,</span> <span class="n">giou_loss</span><span class="p">,</span> <span class="n">objectness_loss</span> <span class="o">=</span> <span class="n">float_tensor</span><span class="p">([</span><span class="mi">0</span><span class="p">]),</span> <span class="n">float_tensor</span><span class="p">([</span><span class="mi">0</span><span class="p">]),</span> <span class="n">float_tensor</span><span class="p">([</span><span class="mi">0</span><span class="p">])</span>
  91. <span class="n">target_class</span><span class="p">,</span> <span class="n">target_box</span><span class="p">,</span> <span class="n">indices</span><span class="p">,</span> <span class="n">anchor_vec</span> <span class="o">=</span> <span class="n">detection_targets</span>
  92. <span class="n">reduction</span> <span class="o">=</span> <span class="s1">&#39;mean&#39;</span> <span class="c1"># Loss reduction (sum or mean)</span>
  93. <span class="c1"># DEFINE CRITERIA</span>
  94. <span class="n">BCEcls</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BCEWithLogitsLoss</span><span class="p">(</span><span class="n">pos_weight</span><span class="o">=</span><span class="n">float_tensor</span><span class="p">([</span><span class="bp">self</span><span class="o">.</span><span class="n">cls_pw</span><span class="p">]),</span> <span class="n">reduction</span><span class="o">=</span><span class="n">reduction</span><span class="p">)</span>
  95. <span class="n">BCEobj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BCEWithLogitsLoss</span><span class="p">(</span><span class="n">pos_weight</span><span class="o">=</span><span class="n">float_tensor</span><span class="p">([</span><span class="bp">self</span><span class="o">.</span><span class="n">obj_pw</span><span class="p">]),</span> <span class="n">reduction</span><span class="o">=</span><span class="n">reduction</span><span class="p">)</span>
  96. <span class="c1"># COMPUTE THE LOSSES BASED ON EACH ONE OF THE YOLO LAYERS PREDICTIONS</span>
  97. <span class="n">grid_points_num</span><span class="p">,</span> <span class="n">targets_num</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span>
  98. <span class="k">for</span> <span class="n">yolo_layer_index</span><span class="p">,</span> <span class="n">yolo_layer_prediction</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">predictions</span><span class="p">):</span>
  99. <span class="n">image</span><span class="p">,</span> <span class="n">anchor</span><span class="p">,</span> <span class="n">grid_y</span><span class="p">,</span> <span class="n">grid_x</span> <span class="o">=</span> <span class="n">indices</span><span class="p">[</span><span class="n">yolo_layer_index</span><span class="p">]</span>
  100. <span class="n">target_object</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">yolo_layer_prediction</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">0</span><span class="p">])</span>
  101. <span class="n">grid_points_num</span> <span class="o">+=</span> <span class="n">target_object</span><span class="o">.</span><span class="n">numel</span><span class="p">()</span>
  102. <span class="c1"># COMPUTE LOSSES</span>
  103. <span class="n">nb</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">image</span><span class="p">)</span>
  104. <span class="k">if</span> <span class="n">nb</span><span class="p">:</span> <span class="c1"># number of targets</span>
  105. <span class="n">targets_num</span> <span class="o">+=</span> <span class="n">nb</span>
  106. <span class="n">predictions_for_targets</span> <span class="o">=</span> <span class="n">yolo_layer_prediction</span><span class="p">[</span><span class="n">image</span><span class="p">,</span> <span class="n">anchor</span><span class="p">,</span> <span class="n">grid_y</span><span class="p">,</span> <span class="n">grid_x</span><span class="p">]</span>
  107. <span class="n">target_object</span><span class="p">[</span><span class="n">image</span><span class="p">,</span> <span class="n">anchor</span><span class="p">,</span> <span class="n">grid_y</span><span class="p">,</span> <span class="n">grid_x</span><span class="p">]</span> <span class="o">=</span> <span class="mf">1.0</span>
  108. <span class="c1"># GIoU LOSS CALCULATION</span>
  109. <span class="n">pxy</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span>
  110. <span class="n">predictions_for_targets</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">:</span><span class="mi">2</span><span class="p">])</span> <span class="c1"># pxy = pxy * s - (s - 1) / 2, s = 1.5 (scale_xy)</span>
  111. <span class="n">bbox_prediction</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
  112. <span class="p">(</span><span class="n">pxy</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">predictions_for_targets</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:</span><span class="mi">4</span><span class="p">])</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">max</span><span class="o">=</span><span class="mf">1E3</span><span class="p">)</span> <span class="o">*</span> <span class="n">anchor_vec</span><span class="p">[</span><span class="n">yolo_layer_index</span><span class="p">]),</span> <span class="mi">1</span><span class="p">)</span>
  113. <span class="n">giou</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">-</span> <span class="n">calculate_bbox_iou_elementwise</span><span class="p">(</span><span class="n">bbox_prediction</span><span class="o">.</span><span class="n">t</span><span class="p">(),</span> <span class="n">target_box</span><span class="p">[</span><span class="n">yolo_layer_index</span><span class="p">],</span>
  114. <span class="n">x1y1x2y2</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">GIoU</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  115. <span class="n">giou_loss</span> <span class="o">+=</span> <span class="n">giou</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <span class="k">if</span> <span class="n">reduction</span> <span class="o">==</span> <span class="s1">&#39;sum&#39;</span> <span class="k">else</span> <span class="n">giou</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>
  116. <span class="c1"># ONLY RELEVANT TO MULTIPLE CLASSES</span>
  117. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">classes_num</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
  118. <span class="n">class_targets</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">predictions_for_targets</span><span class="p">[:,</span> <span class="mi">5</span><span class="p">:])</span>
  119. <span class="n">class_targets</span><span class="p">[</span><span class="nb">range</span><span class="p">(</span><span class="n">nb</span><span class="p">),</span> <span class="n">target_class</span><span class="p">[</span><span class="n">yolo_layer_index</span><span class="p">]]</span> <span class="o">=</span> <span class="mf">1.0</span>
  120. <span class="n">class_loss</span> <span class="o">+=</span> <span class="n">BCEcls</span><span class="p">(</span><span class="n">predictions_for_targets</span><span class="p">[:,</span> <span class="mi">5</span><span class="p">:],</span> <span class="n">class_targets</span><span class="p">)</span>
  121. <span class="n">objectness_loss</span> <span class="o">+=</span> <span class="n">BCEobj</span><span class="p">(</span><span class="n">yolo_layer_prediction</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">4</span><span class="p">],</span> <span class="n">target_object</span><span class="p">)</span>
  122. <span class="k">if</span> <span class="n">reduction</span> <span class="o">==</span> <span class="s1">&#39;sum&#39;</span><span class="p">:</span>
  123. <span class="n">giou_loss</span> <span class="o">*=</span> <span class="mi">3</span> <span class="o">/</span> <span class="n">targets_num</span>
  124. <span class="n">objectness_loss</span> <span class="o">*=</span> <span class="mi">3</span> <span class="o">/</span> <span class="n">grid_points_num</span>
  125. <span class="n">class_loss</span> <span class="o">*=</span> <span class="mi">3</span> <span class="o">/</span> <span class="n">targets_num</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">classes_num</span>
  126. <span class="n">loss</span> <span class="o">=</span> <span class="n">giou_loss</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">giou</span> <span class="o">+</span> <span class="n">objectness_loss</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">obj</span> <span class="o">+</span> <span class="n">class_loss</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">cls</span>
  127. <span class="k">return</span> <span class="n">loss</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">giou_loss</span><span class="p">,</span> <span class="n">objectness_loss</span><span class="p">,</span> <span class="n">class_loss</span><span class="p">,</span> <span class="n">loss</span><span class="p">))</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span></div></div>
  128. </pre></div>
  129. </div>
  130. </div>
  131. <footer>
  132. <hr/>
  133. <div role="contentinfo">
  134. <p>&#169; Copyright 2021, SuperGradients team.</p>
  135. </div>
  136. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  137. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  138. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  139. </footer>
  140. </div>
  141. </div>
  142. </section>
  143. </div>
  144. <script>
  145. jQuery(function () {
  146. SphinxRtdTheme.Navigation.enable(true);
  147. });
  148. </script>
  149. </body>
  150. </html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.training.losses.yolo_v5_loss &mdash; SuperGradients 1.0 documentation</title>
  7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
  10. <!--[if lt IE 9]>
  11. <script src="../../../../_static/js/html5shiv.min.js"></script>
  12. <![endif]-->
  13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
  14. <script src="../../../../_static/jquery.js"></script>
  15. <script src="../../../../_static/underscore.js"></script>
  16. <script src="../../../../_static/doctools.js"></script>
  17. <script src="../../../../_static/js/theme.js"></script>
  18. <link rel="index" title="Index" href="../../../../genindex.html" />
  19. <link rel="search" title="Search" href="../../../../search.html" />
  20. </head>
  21. <body class="wy-body-for-nav">
  22. <div class="wy-grid-for-nav">
  23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  24. <div class="wy-side-scroll">
  25. <div class="wy-side-nav-search" >
  26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
  27. </a>
  28. <div role="search">
  29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
  30. <input type="text" name="q" placeholder="Search docs" />
  31. <input type="hidden" name="check_keywords" value="yes" />
  32. <input type="hidden" name="area" value="default" />
  33. </form>
  34. </div>
  35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
  37. <ul>
  38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">SuperGradients</a></li>
  39. </ul>
  40. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
  41. <ul>
  42. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
  44. </ul>
  45. <p class="caption"><span class="caption-text">User Guide</span></p>
  46. <ul>
  47. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
  57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
  58. </ul>
  59. </div>
  60. </div>
  61. </nav>
  62. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  63. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  64. <a href="../../../../index.html">SuperGradients</a>
  65. </nav>
  66. <div class="wy-nav-content">
  67. <div class="rst-content">
  68. <div role="navigation" aria-label="Page navigation">
  69. <ul class="wy-breadcrumbs">
  70. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  71. <li><a href="../../../index.html">Module code</a> &raquo;</li>
  72. <li>super_gradients.training.losses.yolo_v5_loss</li>
  73. <li class="wy-breadcrumbs-aside">
  74. </li>
  75. </ul>
  76. <hr/>
  77. </div>
  78. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  79. <div itemprop="articleBody">
  80. <h1>Source code for super_gradients.training.losses.yolo_v5_loss</h1><div class="highlight"><pre>
  81. <span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span>
  82. <span class="kn">import</span> <span class="nn">torch</span>
  83. <span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
  84. <span class="kn">from</span> <span class="nn">torch.nn.modules.loss</span> <span class="kn">import</span> <span class="n">_Loss</span>
  85. <span class="kn">from</span> <span class="nn">super_gradients.training.losses.focal_loss</span> <span class="kn">import</span> <span class="n">FocalLoss</span>
  86. <span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">calculate_bbox_iou_elementwise</span><span class="p">,</span> <span class="n">Anchors</span>
  87. <div class="viewcode-block" id="YoLoV5DetectionLoss"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.YoLoV5DetectionLoss">[docs]</a><span class="k">class</span> <span class="nc">YoLoV5DetectionLoss</span><span class="p">(</span><span class="n">_Loss</span><span class="p">):</span>
  88. <span class="sd">&quot;&quot;&quot;</span>
  89. <span class="sd"> Calculate YOLO V5 loss:</span>
  90. <span class="sd"> L = L_objectivness + L_boxes + L_classification</span>
  91. <span class="sd"> &quot;&quot;&quot;</span>
  92. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">anchors</span><span class="p">:</span> <span class="n">Anchors</span><span class="p">,</span>
  93. <span class="n">cls_pos_weight</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]]</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span> <span class="n">obj_pos_weight</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
  94. <span class="n">obj_loss_gain</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span> <span class="n">box_loss_gain</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.05</span><span class="p">,</span> <span class="n">cls_loss_gain</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span>
  95. <span class="n">focal_loss_gamma</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
  96. <span class="n">cls_objectness_weights</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">],</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">anchor_threshold</span><span class="o">=</span><span class="mf">4.0</span><span class="p">):</span>
  97. <span class="sd">&quot;&quot;&quot;</span>
  98. <span class="sd"> :param anchors: the anchors of the model (same anchors used for training)</span>
  99. <span class="sd"> :param cls_pos_weight: pos_weight for BCE in L_classification,</span>
  100. <span class="sd"> can be one value for all positives or a list of weights for each class</span>
  101. <span class="sd"> :param obj_pos_weight: pos_weight for BCE in L_objectivness</span>
  102. <span class="sd"> :param obj_loss_gain: coef for L_objectivness</span>
  103. <span class="sd"> :param box_loss_gain: coef for L_boxes</span>
  104. <span class="sd"> :param cls_loss_gain: coef for L_classification</span>
  105. <span class="sd"> :param focal_loss_gamma: gamma for a focal loss, 0 to train with a usual BCE</span>
  106. <span class="sd"> :param cls_objectness_weights: class-based weight for L_objectivness that will be applied in each cell that</span>
  107. <span class="sd"> has a GT assigned to it.</span>
  108. <span class="sd"> Note: default weight for objectness loss in each cell is 1.</span>
  109. <span class="sd"> :param anchor_threshold: ratio defining a size range of an appropriate anchor.</span>
  110. <span class="sd"> &quot;&quot;&quot;</span>
  111. <span class="nb">super</span><span class="p">(</span><span class="n">YoLoV5DetectionLoss</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
  112. <span class="bp">self</span><span class="o">.</span><span class="n">cls_pos_weight</span> <span class="o">=</span> <span class="n">cls_pos_weight</span>
  113. <span class="bp">self</span><span class="o">.</span><span class="n">obj_pos_weight</span> <span class="o">=</span> <span class="n">obj_pos_weight</span>
  114. <span class="bp">self</span><span class="o">.</span><span class="n">obj_loss_gain</span> <span class="o">=</span> <span class="n">obj_loss_gain</span>
  115. <span class="bp">self</span><span class="o">.</span><span class="n">box_loss_gain</span> <span class="o">=</span> <span class="n">box_loss_gain</span>
  116. <span class="bp">self</span><span class="o">.</span><span class="n">cls_loss_gain</span> <span class="o">=</span> <span class="n">cls_loss_gain</span>
  117. <span class="bp">self</span><span class="o">.</span><span class="n">focal_loss_gamma</span> <span class="o">=</span> <span class="n">focal_loss_gamma</span>
  118. <span class="bp">self</span><span class="o">.</span><span class="n">anchor_threshold</span> <span class="o">=</span> <span class="n">anchor_threshold</span>
  119. <span class="bp">self</span><span class="o">.</span><span class="n">anchors</span> <span class="o">=</span> <span class="n">anchors</span>
  120. <span class="bp">self</span><span class="o">.</span><span class="n">cls_obj_weights</span> <span class="o">=</span> <span class="n">cls_objectness_weights</span>
  121. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">cls_objectness_weights</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
  122. <span class="bp">self</span><span class="o">.</span><span class="n">cls_obj_weights</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">cls_objectness_weights</span><span class="p">))</span>
  123. <div class="viewcode-block" id="YoLoV5DetectionLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.YoLoV5DetectionLoss.forward">[docs]</a> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model_output</span><span class="p">,</span> <span class="n">targets</span><span class="p">):</span>
  124. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">model_output</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">model_output</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span>
  125. <span class="c1"># in test/eval mode the Yolo v5 model output a tuple where the second item is the raw predictions</span>
  126. <span class="n">_</span><span class="p">,</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model_output</span>
  127. <span class="k">else</span><span class="p">:</span>
  128. <span class="n">predictions</span> <span class="o">=</span> <span class="n">model_output</span>
  129. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">compute_loss</span><span class="p">(</span><span class="n">predictions</span><span class="p">,</span> <span class="n">targets</span><span class="p">)</span></div>
  130. <div class="viewcode-block" id="YoLoV5DetectionLoss.build_targets"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.YoLoV5DetectionLoss.build_targets">[docs]</a> <span class="k">def</span> <span class="nf">build_targets</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">predictions</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">targets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> \
  131. <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]],</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]]:</span>
  132. <span class="sd">&quot;&quot;&quot;</span>
  133. <span class="sd"> Assign targets to anchors to use in L_boxes &amp; L_classification calculation:</span>
  134. <span class="sd"> * each target can be assigned to a few anchors,</span>
  135. <span class="sd"> all anchors that are within [1/self.anchor_threshold, self.anchor_threshold] times target size range</span>
  136. <span class="sd"> * each anchor can be assigned to a few targets</span>
  137. <span class="sd"> :param predictions: Yolo predictions</span>
  138. <span class="sd"> :param targets: ground truth targets</span>
  139. <span class="sd"> :return: each of 4 outputs contains one element for each Yolo output,</span>
  140. <span class="sd"> correspondences are raveled over the whole batch and all anchors:</span>
  141. <span class="sd"> * classes of the targets;</span>
  142. <span class="sd"> * boxes of the targets;</span>
  143. <span class="sd"> * image id in a batch, anchor id, grid y, grid x coordinates;</span>
  144. <span class="sd"> * anchor sizes.</span>
  145. <span class="sd"> All the above can be indexed in parallel to get the selected correspondences</span>
  146. <span class="sd"> &quot;&quot;&quot;</span>
  147. <span class="n">num_anchors</span><span class="p">,</span> <span class="n">num_targets</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">anchors</span><span class="o">.</span><span class="n">num_anchors</span><span class="p">,</span> <span class="n">targets</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
  148. <span class="n">target_classes</span><span class="p">,</span> <span class="n">target_boxes</span><span class="p">,</span> <span class="n">indices</span><span class="p">,</span> <span class="n">anchors</span> <span class="o">=</span> <span class="p">[],</span> <span class="p">[],</span> <span class="p">[],</span> <span class="p">[]</span>
  149. <span class="n">gain</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="mi">7</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">targets</span><span class="o">.</span><span class="n">device</span><span class="p">)</span> <span class="c1"># normalized to gridspace gain</span>
  150. <span class="n">anchor_indices</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">num_anchors</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">targets</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
  151. <span class="n">anchor_indices</span> <span class="o">=</span> <span class="n">anchor_indices</span><span class="o">.</span><span class="n">float</span><span class="p">()</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">num_anchors</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">num_targets</span><span class="p">)</span>
  152. <span class="c1"># repeat all targets for each anchor and append a corresponding anchor index</span>
  153. <span class="n">targets</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">targets</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span><span class="n">num_anchors</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">anchor_indices</span><span class="p">[:,</span> <span class="p">:,</span> <span class="kc">None</span><span class="p">]),</span> <span class="mi">2</span><span class="p">)</span>
  154. <span class="n">bias</span> <span class="o">=</span> <span class="mf">0.5</span>
  155. <span class="n">off</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span>
  156. <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span> <span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="c1"># j,k,l,m</span>
  157. <span class="p">],</span> <span class="n">device</span><span class="o">=</span><span class="n">targets</span><span class="o">.</span><span class="n">device</span><span class="p">)</span><span class="o">.</span><span class="n">float</span><span class="p">()</span> <span class="o">*</span> <span class="n">bias</span> <span class="c1"># offsets</span>
  158. <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">anchors</span><span class="o">.</span><span class="n">detection_layers_num</span><span class="p">):</span>
  159. <span class="n">anch</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">anchors</span><span class="o">.</span><span class="n">anchors</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
  160. <span class="n">gain</span><span class="p">[</span><span class="mi">2</span><span class="p">:</span><span class="mi">6</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">predictions</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)[[</span><span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">]]</span> <span class="c1"># xyxy gain</span>
  161. <span class="c1"># Convert target coordinates from [0, 1] range to coordinates in [0, GridY], [0, GridX] ranges</span>
  162. <span class="n">t</span> <span class="o">=</span> <span class="n">targets</span> <span class="o">*</span> <span class="n">gain</span>
  163. <span class="k">if</span> <span class="n">num_targets</span><span class="p">:</span>
  164. <span class="c1"># Match: filter targets by anchor size ratio</span>
  165. <span class="n">r</span> <span class="o">=</span> <span class="n">t</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">4</span><span class="p">:</span><span class="mi">6</span><span class="p">]</span> <span class="o">/</span> <span class="n">anch</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="c1"># wh ratio</span>
  166. <span class="n">filtered_targets_ids</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">r</span><span class="p">,</span> <span class="mf">1.</span> <span class="o">/</span> <span class="n">r</span><span class="p">)</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="mi">2</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">anchor_threshold</span> <span class="c1"># compare</span>
  167. <span class="n">t</span> <span class="o">=</span> <span class="n">t</span><span class="p">[</span><span class="n">filtered_targets_ids</span><span class="p">]</span>
  168. <span class="c1"># Find coordinates of targets on a grid</span>
  169. <span class="n">gxy</span> <span class="o">=</span> <span class="n">t</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:</span><span class="mi">4</span><span class="p">]</span> <span class="c1"># grid xy</span>
  170. <span class="n">gxi</span> <span class="o">=</span> <span class="n">gain</span><span class="p">[[</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">]]</span> <span class="o">-</span> <span class="n">gxy</span> <span class="c1"># inverse</span>
  171. <span class="n">j</span><span class="p">,</span> <span class="n">k</span> <span class="o">=</span> <span class="p">((</span><span class="n">gxy</span> <span class="o">%</span> <span class="mf">1.</span> <span class="o">&lt;</span> <span class="n">bias</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">gxy</span> <span class="o">&gt;</span> <span class="mf">1.</span><span class="p">))</span><span class="o">.</span><span class="n">T</span>
  172. <span class="n">l</span><span class="p">,</span> <span class="n">m</span> <span class="o">=</span> <span class="p">((</span><span class="n">gxi</span> <span class="o">%</span> <span class="mf">1.</span> <span class="o">&lt;</span> <span class="n">bias</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">gxi</span> <span class="o">&gt;</span> <span class="mf">1.</span><span class="p">))</span><span class="o">.</span><span class="n">T</span>
  173. <span class="n">j</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">((</span><span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">j</span><span class="p">),</span> <span class="n">j</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">l</span><span class="p">,</span> <span class="n">m</span><span class="p">))</span>
  174. <span class="n">t</span> <span class="o">=</span> <span class="n">t</span><span class="o">.</span><span class="n">repeat</span><span class="p">((</span><span class="mi">5</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))[</span><span class="n">j</span><span class="p">]</span>
  175. <span class="n">offsets</span> <span class="o">=</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">gxy</span><span class="p">)[</span><span class="kc">None</span><span class="p">]</span> <span class="o">+</span> <span class="n">off</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">])[</span><span class="n">j</span><span class="p">]</span>
  176. <span class="k">else</span><span class="p">:</span>
  177. <span class="n">t</span> <span class="o">=</span> <span class="n">targets</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
  178. <span class="n">offsets</span> <span class="o">=</span> <span class="mi">0</span>
  179. <span class="c1"># Define</span>
  180. <span class="n">b</span><span class="p">,</span> <span class="n">c</span> <span class="o">=</span> <span class="n">t</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">long</span><span class="p">()</span><span class="o">.</span><span class="n">T</span> <span class="c1"># image, class</span>
  181. <span class="n">gxy</span> <span class="o">=</span> <span class="n">t</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:</span><span class="mi">4</span><span class="p">]</span> <span class="c1"># grid xy</span>
  182. <span class="n">gwh</span> <span class="o">=</span> <span class="n">t</span><span class="p">[:,</span> <span class="mi">4</span><span class="p">:</span><span class="mi">6</span><span class="p">]</span> <span class="c1"># grid wh</span>
  183. <span class="n">gij</span> <span class="o">=</span> <span class="p">(</span><span class="n">gxy</span> <span class="o">-</span> <span class="n">offsets</span><span class="p">)</span><span class="o">.</span><span class="n">long</span><span class="p">()</span>
  184. <span class="n">gi</span><span class="p">,</span> <span class="n">gj</span> <span class="o">=</span> <span class="n">gij</span><span class="o">.</span><span class="n">T</span> <span class="c1"># grid xy indices</span>
  185. <span class="c1"># prevent coordinates from going out of bounds</span>
  186. <span class="n">gi</span><span class="p">,</span> <span class="n">gj</span> <span class="o">=</span> <span class="n">gi</span><span class="o">.</span><span class="n">clamp_</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">gain</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span><span class="p">),</span> <span class="n">gj</span><span class="o">.</span><span class="n">clamp_</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">gain</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
  187. <span class="c1"># Append</span>
  188. <span class="n">a</span> <span class="o">=</span> <span class="n">t</span><span class="p">[:,</span> <span class="mi">6</span><span class="p">]</span><span class="o">.</span><span class="n">long</span><span class="p">()</span> <span class="c1"># anchor indices</span>
  189. <span class="n">indices</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">b</span><span class="p">,</span> <span class="n">a</span><span class="p">,</span> <span class="n">gj</span><span class="p">,</span> <span class="n">gi</span><span class="p">))</span> <span class="c1"># image, anchor, grid indices</span>
  190. <span class="n">target_boxes</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">gxy</span> <span class="o">-</span> <span class="n">gij</span><span class="p">,</span> <span class="n">gwh</span><span class="p">),</span> <span class="mi">1</span><span class="p">))</span> <span class="c1"># box</span>
  191. <span class="n">anchors</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">anch</span><span class="p">[</span><span class="n">a</span><span class="p">])</span> <span class="c1"># anchors</span>
  192. <span class="n">target_classes</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">c</span><span class="p">)</span> <span class="c1"># class</span>
  193. <span class="k">return</span> <span class="n">target_classes</span><span class="p">,</span> <span class="n">target_boxes</span><span class="p">,</span> <span class="n">indices</span><span class="p">,</span> <span class="n">anchors</span></div>
  194. <div class="viewcode-block" id="YoLoV5DetectionLoss.compute_loss"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.YoLoV5DetectionLoss.compute_loss">[docs]</a> <span class="k">def</span> <span class="nf">compute_loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">predictions</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">targets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">giou_loss_ratio</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">)</span> \
  195. <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]:</span>
  196. <span class="sd">&quot;&quot;&quot;</span>
  197. <span class="sd"> L = L_objectivness + L_boxes + L_classification</span>
  198. <span class="sd"> where:</span>
  199. <span class="sd"> * L_boxes and L_classification are calculated only between anchors and targets that suit them;</span>
  200. <span class="sd"> * L_objectivness is calculated on all anchors.</span>
  201. <span class="sd"> L_classification:</span>
  202. <span class="sd"> for anchors that have suitable ground truths in their grid locations add BCEs</span>
  203. <span class="sd"> to force max probability for each GT class in a multi-label way</span>
  204. <span class="sd"> Coef: self.cls_loss_gain</span>
  205. <span class="sd"> L_boxes:</span>
  206. <span class="sd"> for anchors that have suitable ground truths in their grid locations</span>
  207. <span class="sd"> add (1 - IoU), IoU between a predicted box and each GT box, force maximum IoU</span>
  208. <span class="sd"> Coef: self.box_loss_gain</span>
  209. <span class="sd"> L_objectness:</span>
  210. <span class="sd"> for each anchor add BCE to force a prediction of (1 - giou_loss_ratio) + giou_loss_ratio * IoU,</span>
  211. <span class="sd"> IoU between a predicted box and random GT in it</span>
  212. <span class="sd"> Coef: self.obj_loss_gain, loss from each YOLO grid is additionally multiplied by balance = [4.0, 1.0, 0.4]</span>
  213. <span class="sd"> to balance different contributions coming from different numbers of grid cells</span>
  214. <span class="sd"> :param predictions: output from all Yolo levels, each of shape</span>
  215. <span class="sd"> [Batch x Num_Anchors x GridSizeY x GridSizeX x (4 + 1 + Num_classes)]</span>
  216. <span class="sd"> :param targets: [Num_targets x (4 + 2)], values on dim 1 are: image id in a batch, class, box x y w h</span>
  217. <span class="sd"> :param giou_loss_ratio: a coef in L_objectness defining what should be predicted as objecness</span>
  218. <span class="sd"> in a call with a target: can be a value in [IoU, 1] range</span>
  219. <span class="sd"> :return: loss, all losses separately in a detached tensor</span>
  220. <span class="sd"> &quot;&quot;&quot;</span>
  221. <span class="n">device</span> <span class="o">=</span> <span class="n">targets</span><span class="o">.</span><span class="n">device</span>
  222. <span class="n">loss_classification</span><span class="p">,</span> <span class="n">loss_boxes</span><span class="p">,</span> <span class="n">loss_objectivness</span> <span class="o">=</span> \
  223. <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">),</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">),</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
  224. <span class="n">target_classes</span><span class="p">,</span> <span class="n">target_boxes</span><span class="p">,</span> <span class="n">indices</span><span class="p">,</span> <span class="n">anchors</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">build_targets</span><span class="p">(</span><span class="n">predictions</span><span class="p">,</span> <span class="n">targets</span><span class="p">)</span> <span class="c1"># targets</span>
  225. <span class="c1"># Define criteria</span>
  226. <span class="n">BCEcls</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BCEWithLogitsLoss</span><span class="p">(</span><span class="n">pos_weight</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">([</span><span class="bp">self</span><span class="o">.</span><span class="n">cls_pos_weight</span><span class="p">]))</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
  227. <span class="n">BCEobj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BCEWithLogitsLoss</span><span class="p">(</span><span class="n">pos_weight</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">([</span><span class="bp">self</span><span class="o">.</span><span class="n">obj_pos_weight</span><span class="p">]),</span> <span class="n">reduction</span><span class="o">=</span><span class="s1">&#39;none&#39;</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
  228. <span class="c1"># Focal loss</span>
  229. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">focal_loss_gamma</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
  230. <span class="n">BCEcls</span><span class="p">,</span> <span class="n">BCEobj</span> <span class="o">=</span> <span class="n">FocalLoss</span><span class="p">(</span><span class="n">BCEcls</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">focal_loss_gamma</span><span class="p">),</span> <span class="n">FocalLoss</span><span class="p">(</span><span class="n">BCEobj</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">focal_loss_gamma</span><span class="p">)</span>
  231. <span class="c1"># Losses</span>
  232. <span class="n">num_targets</span> <span class="o">=</span> <span class="mi">0</span>
  233. <span class="n">num_predictions</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span>
  234. <span class="n">balance</span> <span class="o">=</span> <span class="p">[</span><span class="mf">4.0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.4</span><span class="p">]</span> <span class="k">if</span> <span class="n">num_predictions</span> <span class="o">==</span> <span class="mi">3</span> <span class="k">else</span> <span class="p">[</span><span class="mf">4.0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.4</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">]</span> <span class="c1"># P3-5 or P3-6</span>
  235. <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">prediction</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">predictions</span><span class="p">):</span> <span class="c1"># layer index, layer predictions</span>
  236. <span class="n">image</span><span class="p">,</span> <span class="n">anchor</span><span class="p">,</span> <span class="n">grid_y</span><span class="p">,</span> <span class="n">grid_x</span> <span class="o">=</span> <span class="n">indices</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
  237. <span class="n">target_obj</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">prediction</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
  238. <span class="n">weight_obj</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">prediction</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
  239. <span class="n">n</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="c1"># number of targets</span>
  240. <span class="k">if</span> <span class="n">n</span><span class="p">:</span>
  241. <span class="n">num_targets</span> <span class="o">+=</span> <span class="n">n</span> <span class="c1"># cumulative targets</span>
  242. <span class="n">ps</span> <span class="o">=</span> <span class="n">prediction</span><span class="p">[</span><span class="n">image</span><span class="p">,</span> <span class="n">anchor</span><span class="p">,</span> <span class="n">grid_y</span><span class="p">,</span> <span class="n">grid_x</span><span class="p">]</span> <span class="c1"># prediction subset corresponding to targets</span>
  243. <span class="c1"># Boxes loss</span>
  244. <span class="n">pxy</span> <span class="o">=</span> <span class="n">ps</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">()</span> <span class="o">*</span> <span class="mf">2.</span> <span class="o">-</span> <span class="mf">0.5</span>
  245. <span class="n">pwh</span> <span class="o">=</span> <span class="p">(</span><span class="n">ps</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:</span><span class="mi">4</span><span class="p">]</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">()</span> <span class="o">*</span> <span class="mi">2</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">anchors</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
  246. <span class="n">pbox</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">pxy</span><span class="p">,</span> <span class="n">pwh</span><span class="p">),</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span> <span class="c1"># predicted box</span>
  247. <span class="n">iou</span> <span class="o">=</span> <span class="n">calculate_bbox_iou_elementwise</span><span class="p">(</span><span class="n">pbox</span><span class="o">.</span><span class="n">T</span><span class="p">,</span> <span class="n">target_boxes</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">x1y1x2y2</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">CIoU</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  248. <span class="n">loss_boxes</span> <span class="o">+=</span> <span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">iou</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span> <span class="c1"># iou loss</span>
  249. <span class="c1"># Objectness loss target</span>
  250. <span class="n">target_obj</span><span class="p">[</span><span class="n">image</span><span class="p">,</span> <span class="n">anchor</span><span class="p">,</span> <span class="n">grid_y</span><span class="p">,</span> <span class="n">grid_x</span><span class="p">]</span> <span class="o">=</span> \
  251. <span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">giou_loss_ratio</span><span class="p">)</span> <span class="o">+</span> <span class="n">giou_loss_ratio</span> <span class="o">*</span> <span class="n">iou</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">type</span><span class="p">(</span><span class="n">target_obj</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
  252. <span class="c1"># Weights for weighted objectness</span>
  253. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">cls_obj_weights</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  254. <span class="c1"># NOTE: for grid cells that have a few ground truths with different classes assigned to them</span>
  255. <span class="c1"># objectness weight will be picked randomly from one of these classes</span>
  256. <span class="n">weight_obj</span><span class="p">[</span><span class="n">image</span><span class="p">,</span> <span class="n">anchor</span><span class="p">,</span> <span class="n">grid_y</span><span class="p">,</span> <span class="n">grid_x</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">cls_obj_weights</span><span class="p">[</span><span class="n">target_classes</span><span class="p">[</span><span class="n">i</span><span class="p">]]</span>
  257. <span class="c1"># Classification loss</span>
  258. <span class="k">if</span> <span class="n">ps</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">&gt;</span> <span class="mi">6</span><span class="p">:</span> <span class="c1"># cls loss (only if multiple classes)</span>
  259. <span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">full_like</span><span class="p">(</span><span class="n">ps</span><span class="p">[:,</span> <span class="mi">5</span><span class="p">:],</span> <span class="mi">0</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span> <span class="c1"># targets</span>
  260. <span class="n">t</span><span class="p">[</span><span class="nb">range</span><span class="p">(</span><span class="n">n</span><span class="p">),</span> <span class="n">target_classes</span><span class="p">[</span><span class="n">i</span><span class="p">]]</span> <span class="o">=</span> <span class="mi">1</span>
  261. <span class="n">loss_classification</span> <span class="o">+=</span> <span class="n">BCEcls</span><span class="p">(</span><span class="n">ps</span><span class="p">[:,</span> <span class="mi">5</span><span class="p">:],</span> <span class="n">t</span><span class="p">)</span> <span class="c1"># BCE</span>
  262. <span class="c1"># Objectness loss</span>
  263. <span class="n">loss_obj_cur_head</span> <span class="o">=</span> <span class="n">BCEobj</span><span class="p">(</span><span class="n">prediction</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">4</span><span class="p">],</span> <span class="n">target_obj</span><span class="p">)</span>
  264. <span class="n">loss_obj_cur_head</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">loss_obj_cur_head</span> <span class="o">*</span> <span class="n">weight_obj</span> <span class="o">/</span> <span class="n">torch</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">weight_obj</span><span class="p">))</span>
  265. <span class="n">loss_objectivness</span> <span class="o">+=</span> <span class="n">loss_obj_cur_head</span> <span class="o">*</span> <span class="n">balance</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="c1"># obj loss</span>
  266. <span class="n">batch_size</span> <span class="o">=</span> <span class="n">prediction</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="c1"># batch size</span>
  267. <span class="n">loss</span> <span class="o">=</span> <span class="n">loss_boxes</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">box_loss_gain</span> <span class="o">+</span> <span class="n">loss_objectivness</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">obj_loss_gain</span> <span class="o">+</span> <span class="n">loss_classification</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">cls_loss_gain</span>
  268. <span class="c1"># IMPORTANT: box, obj and cls loss are logged scaled by gain in ultralytics</span>
  269. <span class="c1"># and are logged unscaled in our codebase</span>
  270. <span class="k">return</span> <span class="n">loss</span> <span class="o">*</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">loss_boxes</span><span class="p">,</span> <span class="n">loss_objectivness</span><span class="p">,</span> <span class="n">loss_classification</span><span class="p">,</span> <span class="n">loss</span><span class="p">))</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span></div></div>
  271. </pre></div>
  272. </div>
  273. </div>
  274. <footer>
  275. <hr/>
  276. <div role="contentinfo">
  277. <p>&#169; Copyright 2021, SuperGradients team.</p>
  278. </div>
  279. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  280. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  281. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  282. </footer>
  283. </div>
  284. </div>
  285. </section>
  286. </div>
  287. <script>
  288. jQuery(function () {
  289. SphinxRtdTheme.Navigation.enable(true);
  290. });
  291. </script>
  292. </body>
  293. </html>
Discard
Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
@@ -3,10 +3,11 @@
 <head>
 <head>
   <meta charset="utf-8" />
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-  <title>super_gradients.training.losses.yolox_loss &mdash; SuperGradients 1.0 documentation</title>
+  <title>super_gradients.training.losses.yolox_loss &mdash; SuperGradients 3.0.3 documentation</title>
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
+      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
   <!--[if lt IE 9]>
   <!--[if lt IE 9]>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
   <![endif]-->
   <![endif]-->
@@ -14,7 +15,9 @@
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/underscore.js"></script>
         <script src="../../../../_static/underscore.js"></script>
+        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
         <script src="../../../../_static/doctools.js"></script>
         <script src="../../../../_static/doctools.js"></script>
+        <script src="../../../../_static/sphinx_highlight.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="search" title="Search" href="../../../../search.html" /> 
     <link rel="search" title="Search" href="../../../../search.html" /> 
@@ -35,30 +38,29 @@
   </form>
   </form>
 </div>
 </div>
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
-              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
+              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
 <ul>
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
 </ul>
 </ul>
-<p class="caption"><span class="caption-text">Technical Documentation</span></p>
+<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
 <ul>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
-</ul>
-<p class="caption"><span class="caption-text">User Guide</span></p>
-<ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
 </ul>
 </ul>
 
 
         </div>
         </div>
@@ -95,10 +97,13 @@
 <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span>
 <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span>
 
 
 <span class="kn">import</span> <span class="nn">torch</span>
 <span class="kn">import</span> <span class="nn">torch</span>
+<span class="kn">import</span> <span class="nn">torch.distributed</span> <span class="k">as</span> <span class="nn">dist</span>
 <span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
 <span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
 <span class="kn">from</span> <span class="nn">torch.nn.modules.loss</span> <span class="kn">import</span> <span class="n">_Loss</span>
 <span class="kn">from</span> <span class="nn">torch.nn.modules.loss</span> <span class="kn">import</span> <span class="n">_Loss</span>
 <span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="nn">F</span>
 <span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="nn">F</span>
+
 <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
 <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
+<span class="kn">from</span> <span class="nn">super_gradients.training.utils</span> <span class="kn">import</span> <span class="n">torch_version_is_greater_or_equal</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">calculate_bbox_iou_matrix</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">calculate_bbox_iou_matrix</span>
 
 
 <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
 <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
@@ -127,10 +132,9 @@
         <span class="n">supported_losses</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;iou&quot;</span><span class="p">,</span> <span class="s2">&quot;giou&quot;</span><span class="p">]</span>
         <span class="n">supported_losses</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;iou&quot;</span><span class="p">,</span> <span class="s2">&quot;giou&quot;</span><span class="p">]</span>
         <span class="n">supported_reductions</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;mean&quot;</span><span class="p">,</span> <span class="s2">&quot;sum&quot;</span><span class="p">,</span> <span class="s2">&quot;none&quot;</span><span class="p">]</span>
         <span class="n">supported_reductions</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;mean&quot;</span><span class="p">,</span> <span class="s2">&quot;sum&quot;</span><span class="p">,</span> <span class="s2">&quot;none&quot;</span><span class="p">]</span>
         <span class="k">if</span> <span class="n">loss_type</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">supported_losses</span><span class="p">:</span>
         <span class="k">if</span> <span class="n">loss_type</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">supported_losses</span><span class="p">:</span>
-            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Illegal loss_type value: &quot;</span> <span class="o">+</span> <span class="n">loss_type</span> <span class="o">+</span> <span class="s1">&#39;, expected one of: &#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">supported_losses</span><span class="p">))</span>
+            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Illegal loss_type value: &quot;</span> <span class="o">+</span> <span class="n">loss_type</span> <span class="o">+</span> <span class="s2">&quot;, expected one of: &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">supported_losses</span><span class="p">))</span>
         <span class="k">if</span> <span class="n">reduction</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">supported_reductions</span><span class="p">:</span>
         <span class="k">if</span> <span class="n">reduction</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">supported_reductions</span><span class="p">:</span>
-            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
-                <span class="s2">&quot;Illegal reduction value: &quot;</span> <span class="o">+</span> <span class="n">reduction</span> <span class="o">+</span> <span class="s1">&#39;, expected one of: &#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">supported_reductions</span><span class="p">))</span>
+            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Illegal reduction value: &quot;</span> <span class="o">+</span> <span class="n">reduction</span> <span class="o">+</span> <span class="s2">&quot;, expected one of: &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">supported_reductions</span><span class="p">))</span>
 
 
     <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">pred</span><span class="p">,</span> <span class="n">target</span><span class="p">):</span>
     <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">pred</span><span class="p">,</span> <span class="n">target</span><span class="p">):</span>
         <span class="k">assert</span> <span class="n">pred</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">target</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
         <span class="k">assert</span> <span class="n">pred</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">target</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
@@ -149,7 +153,7 @@
         <span class="n">iou</span> <span class="o">=</span> <span class="p">(</span><span class="n">area_i</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="n">area_u</span> <span class="o">+</span> <span class="mf">1e-16</span><span class="p">)</span>
         <span class="n">iou</span> <span class="o">=</span> <span class="p">(</span><span class="n">area_i</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="n">area_u</span> <span class="o">+</span> <span class="mf">1e-16</span><span class="p">)</span>
 
 
         <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_type</span> <span class="o">==</span> <span class="s2">&quot;iou&quot;</span><span class="p">:</span>
         <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_type</span> <span class="o">==</span> <span class="s2">&quot;iou&quot;</span><span class="p">:</span>
-            <span class="n">loss</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">iou</span> <span class="o">**</span> <span class="mi">2</span>
+            <span class="n">loss</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">iou</span><span class="o">**</span><span class="mi">2</span>
         <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_type</span> <span class="o">==</span> <span class="s2">&quot;giou&quot;</span><span class="p">:</span>
         <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_type</span> <span class="o">==</span> <span class="s2">&quot;giou&quot;</span><span class="p">:</span>
             <span class="n">c_tl</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">min</span><span class="p">((</span><span class="n">pred</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span> <span class="o">-</span> <span class="n">pred</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:]</span> <span class="o">/</span> <span class="mi">2</span><span class="p">),</spa
             <span class="n">c_tl</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">min</span><span class="p">((</span><span class="n">pred</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span> <span class="o">-</span> <span class="n">pred</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:]</span> <span class="o">/</span> <span class="mi">2</span><span class="p">),</spa
             <span class="n">c_br</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">((</span><span class="n">pred</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span> <span class="o">+</span> <span class="n">pred</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:]</span> <span class="o">/</span> <span class="mi">2</span><span class="p">),</spa
             <span class="n">c_br</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">((</span><span class="n">pred</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span> <span class="o">+</span> <span class="n">pred</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:]</span> <span class="o">/</span> <span class="mi">2</span><span class="p">),</spa
@@ -165,7 +169,7 @@
         <span class="k">return</span> <span class="n">loss</span>
         <span class="k">return</span> <span class="n">loss</span>
 
 
 
 
-<div class="viewcode-block" id="YoloXDetectionLoss"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.YoloXDetectionLoss">[docs]</a><span class="k">class</span> <span class="nc">YoloXDetectionLoss</span><span class="p">(</span><span class="n">_Loss</span><span class="p">):</span>
+<div class="viewcode-block" id="YoloXDetectionLoss"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.YoloXDetectionLoss">[docs]</a><span class="k">class</span> <span class="nc">YoloXDetectionLoss</span><span class="p">(</span><span class="n">_Loss</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    Calculate YOLOX loss:</span>
 <span class="sd">    Calculate YOLOX loss:</span>
 <span class="sd">    L = L_objectivness + L_iou + L_classification + 1[use_l1]*L_l1</span>
 <span class="sd">    L = L_objectivness + L_iou + L_classification + 1[use_l1]*L_l1</span>
@@ -200,8 +204,7 @@
 
 
 <span class="sd">    &quot;&quot;&quot;</span>
 <span class="sd">    &quot;&quot;&quot;</span>
 
 
-    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">strides</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">use_l1</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><sp
-                 <span class="n">iou_type</span><span class="o">=</span><span class="s1">&#39;iou&#39;</span><span class="p">):</span>
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">strides</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">use_l1</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><sp
         <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
         <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">grids</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">1</span><span class="p">)]</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">strides</span><span class="p">)</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">grids</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">1</span><span class="p">)]</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">strides</span><span class="p">)</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">strides</span> <span class="o">=</span> <span class="n">strides</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">strides</span> <span class="o">=</span> <span class="n">strides</span>
@@ -213,7 +216,16 @@
         <span class="bp">self</span><span class="o">.</span><span class="n">bcewithlog_loss</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BCEWithLogitsLoss</span><span class="p">(</span><span class="n">reduction</span><span class="o">=</span><span class="s2">&quot;none&quot;</span><span class="p">)</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">bcewithlog_loss</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BCEWithLogitsLoss</span><span class="p">(</span><span class="n">reduction</span><span class="o">=</span><span class="s2">&quot;none&quot;</span><span class="p">)</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">iou_loss</span> <span class="o">=</span> <span class="n">IOUloss</span><span class="p">(</span><span class="n">reduction</span><span class="o">=</span><span class="s2">&quot;none&quot;</span><span class="p">,</span> <span class="n">loss_type</span><span class="o">=</span><span class="n">iou_type</span><span class="p">)</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">iou_loss</span> <span class="o">=</span> <span class="n">IOUloss</span><span class="p">(</span><span class="n">reduction</span><span class="o">=</span><span class="s2">&quot;none&quot;</span><span class="p">,</span> <span class="n">loss_type</span><span class="o">=</span><span class="n">iou_type</span><span class="p">)</span>
 
 
-<div class="viewcode-block" id="YoloXDetectionLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.YoloXDetectionLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model_output</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">list</span><span clas
+    <span class="nd">@property</span>
+    <span class="k">def</span> <span class="nf">component_names</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
+        <span class="sd">&quot;&quot;&quot;</span>
+<span class="sd">        Component names for logging during training.</span>
+<span class="sd">        These correspond to 2nd item in the tuple returned in self.forward(...).</span>
+<span class="sd">        See super_gradients.Trainer.train() docs for more info.</span>
+<span class="sd">        &quot;&quot;&quot;</span>
+        <span class="k">return</span> <span class="p">[</span><span class="s2">&quot;iou&quot;</span><span class="p">,</span> <span class="s2">&quot;obj&quot;</span><span class="p">,</span> <span class="s2">&quot;cls&quot;</span><span class="p">,</span> <span class="s2">&quot;l1&quot;</span><span class="p">,</span> <span class="s2">&quot;num_fg&quot;</span><span class="p">,</span> <span class="s2">&quot;Loss&quot;</span><span class="p">]</span>
+
+<div class="viewcode-block" id="YoloXDetectionLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.YoloXDetectionLoss.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model_output</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">list</span><span class="p">,
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">        :param model_output: Union[list, Tuple[torch.Tensor, List]]:</span>
 <span class="sd">        :param model_output: Union[list, Tuple[torch.Tensor, List]]:</span>
 <span class="sd">             When list-</span>
 <span class="sd">             When list-</span>
@@ -241,11 +253,14 @@
 <span class="sd">        :param ny: int: cells along the y axis (default=20)</span>
 <span class="sd">        :param ny: int: cells along the y axis (default=20)</span>
 <span class="sd">        :return: torch.tensor of xy coordinates of size (1,1,nx,ny,2)</span>
 <span class="sd">        :return: torch.tensor of xy coordinates of size (1,1,nx,ny,2)</span>
 <span class="sd">        &quot;&quot;&quot;</span>
 <span class="sd">        &quot;&quot;&quot;</span>
-        <span class="n">yv</span><span class="p">,</span> <span class="n">xv</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">meshgrid</span><span class="p">([</span><span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">ny</span><span class="p">),</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">nx</s
+        <span class="k">if</span> <span class="n">torch_version_is_greater_or_equal</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">10</span><span class="p">):</span>
+            <span class="c1"># https://github.com/pytorch/pytorch/issues/50276</span>
+            <span class="n">yv</span><span class="p">,</span> <span class="n">xv</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">meshgrid</span><span class="p">([</span><span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">ny</span><span class="p">),</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">n
+        <span class="k">else</span><span class="p">:</span>
+            <span class="n">yv</span><span class="p">,</span> <span class="n">xv</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">meshgrid</span><span class="p">([</span><span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">ny</span><span class="p">),</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">n
         <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">((</span><span class="n">xv</span><span class="p">,</span> <span class="n">yv</span><span class="p">),</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">ny</span><span
         <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">((</span><span class="n">xv</span><span class="p">,</span> <span class="n">yv</span><span class="p">),</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">ny</span><span
 
 
-    <span class="k">def</span> <span class="nf">_compute_loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">predictions</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">targets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><
-            <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]:</span>
+    <span class="k">def</span> <span class="nf">_compute_loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">predictions</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">targets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">        :param predictions:     output from all Yolo levels, each of shape</span>
 <span class="sd">        :param predictions:     output from all Yolo levels, each of shape</span>
 <span class="sd">                                [Batch x 1 x GridSizeY x GridSizeX x (4 + 1 + Num_classes)]</span>
 <span class="sd">                                [Batch x 1 x GridSizeY x GridSizeX x (4 + 1 + Num_classes)]</span>
@@ -267,7 +282,7 @@
         <span class="n">obj_targets</span> <span class="o">=</span> <span class="p">[]</span>
         <span class="n">obj_targets</span> <span class="o">=</span> <span class="p">[]</span>
         <span class="n">fg_masks</span> <span class="o">=</span> <span class="p">[]</span>
         <span class="n">fg_masks</span> <span class="o">=</span> <span class="p">[]</span>
 
 
-        <span class="n">num_fg</span><span class="p">,</span> <span class="n">num_gts</span> <span class="o">=</span> <span class="mf">0.</span><span class="p">,</span> <span class="mf">0.</span>
+        <span class="n">num_fg</span><span class="p">,</span> <span class="n">num_gts</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.0</span>
 
 
         <span class="k">for</span> <span class="n">image_idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">transformed_outputs</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]):</span>
         <span class="k">for</span> <span class="n">image_idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">transformed_outputs</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]):</span>
             <span class="n">labels_im</span> <span class="o">=</span> <span class="n">targets</span><span class="p">[</span><span class="n">targets</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">image_idx</span><span class="p">]</span>
             <span class="n">labels_im</span> <span class="o">=</span> <span class="n">targets</span><span class="p">[</span><span class="n">targets</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">image_idx</span><span class="p">]</span>
@@ -287,21 +302,42 @@
 
 
                 <span class="k">try</span><span class="p">:</span>
                 <span class="k">try</span><span class="p">:</span>
                     <span class="c1"># assign cells to ground truths, at most one GT per cell</span>
                     <span class="c1"># assign cells to ground truths, at most one GT per cell</span>
-                    <span class="n">gt_matched_classes</span><span class="p">,</span> <span class="n">fg_mask</span><span class="p">,</span> <span class="n">pred_ious_this_matching</span><span class="p">,</span> <span class="n">matched_gt_inds</span><span class="p">,</span> <span class="n">num_fg_img</span> <span class="o">=</span> \
-                        <span class="bp">self</span><span class="o">.</span><span class="n">get_assignments</span><span class="p">(</span><span class="n">image_idx</span><span class="p">,</span> <span class="n">num_gt</span><span class="p">,</span> <span class="n">total_num_anchors</span><span class="p">,</span> <span class="n">gt_bboxes_per_image</span><span class="p">,</span>
-                                             <span class="n">gt_classes</span><span class="p">,</span> <span class="n">bboxes_preds_per_image</span><span class="p">,</span>
-                                             <span class="n">expanded_strides</span><span class="p">,</span> <span class="n">x_shifts</span><span class="p">,</span> <span class="n">y_shifts</span><span class="p">,</span> <span class="n">cls_preds</span><span class="p">,</span> <span class="n">obj_preds</span><span class="p">)</span>
+                    <span class="n">gt_matched_classes</span><span class="p">,</span> <span class="n">fg_mask</span><span class="p">,</span> <span class="n">pred_ious_this_matching</span><span class="p">,</span> <span class="n">matched_gt_inds</span><span class="p">,</span> <span class="n">num_fg_img</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_assignments</span><span class="p">(</span>
+                        <span class="n">image_idx</span><span class="p">,</span>
+                        <span class="n">num_gt</span><span class="p">,</span>
+                        <span class="n">total_num_anchors</span><span class="p">,</span>
+                        <span class="n">gt_bboxes_per_image</span><span class="p">,</span>
+                        <span class="n">gt_classes</span><span class="p">,</span>
+                        <span class="n">bboxes_preds_per_image</span><span class="p">,</span>
+                        <span class="n">expanded_strides</span><span class="p">,</span>
+                        <span class="n">x_shifts</span><span class="p">,</span>
+                        <span class="n">y_shifts</span><span class="p">,</span>
+                        <span class="n">cls_preds</span><span class="p">,</span>
+                        <span class="n">obj_preds</span><span class="p">,</span>
+                    <span class="p">)</span>
 
 
                 <span class="c1"># TODO: CHECK IF ERROR IS CUDA OUT OF MEMORY</span>
                 <span class="c1"># TODO: CHECK IF ERROR IS CUDA OUT OF MEMORY</span>
                 <span class="k">except</span> <span class="ne">RuntimeError</span><span class="p">:</span>
                 <span class="k">except</span> <span class="ne">RuntimeError</span><span class="p">:</span>
-                    <span class="n">logging</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="s2">&quot;OOM RuntimeError is raised due to the huge memory cost during label assignment. </span><span class="se">\</span>
+                    <span class="n">logging</span><span class="o">.</span><span class="n">error</span><span class="p">(</span>
+                        <span class="s2">&quot;OOM RuntimeError is raised due to the huge memory cost during label assignment. </span><span class="se">\</span>
 <span class="s2">                                   CPU mode is applied in this batch. If you want to avoid this issue, </span><span class="se">\</span>
 <span class="s2">                                   CPU mode is applied in this batch. If you want to avoid this issue, </span><span class="se">\</span>
-<span class="s2">                                   try to reduce the batch size or image size.&quot;</span><span class="p">)</span>
+<span class="s2">                                   try to reduce the batch size or image size.&quot;</span>
+                    <span class="p">)</span>
                     <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">empty_cache</span><span class="p">()</span>
                     <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">empty_cache</span><span class="p">()</span>
-                    <span class="n">gt_matched_classes</span><span class="p">,</span> <span class="n">fg_mask</span><span class="p">,</span> <span class="n">pred_ious_this_matching</span><span class="p">,</span> <span class="n">matched_gt_inds</span><span class="p">,</span> <span class="n">num_fg_img</span> <span class="o">=</span> \
-                        <span class="bp">self</span><span class="o">.</span><span class="n">get_assignments</span><span class="p">(</span><span class="n">image_idx</span><span class="p">,</span> <span class="n">num_gt</span><span class="p">,</span> <span class="n">total_num_anchors</span><span class="p">,</span> <span class="n">gt_bboxes_per_image</span><span class="p">,</span>
-                                             <span class="n">gt_classes</span><span class="p">,</span> <span class="n">bboxes_preds_per_image</span><span class="p">,</span>
-                                             <span class="n">expanded_strides</span><span class="p">,</span> <span class="n">x_shifts</span><span class="p">,</span> <span class="n">y_shifts</span><span class="p">,</span> <span class="n">cls_preds</span><span class="p">,</span> <span class="n">obj_preds</span><span class="p">,</span> <span class="s1">&#39;cpu&#39;</span><span class="p">)</span>
+                    <span class="n">gt_matched_classes</span><span class="p">,</span> <span class="n">fg_mask</span><span class="p">,</span> <span class="n">pred_ious_this_matching</span><span class="p">,</span> <span class="n">matched_gt_inds</span><span class="p">,</span> <span class="n">num_fg_img</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_assignments</span><span class="p">(</span>
+                        <span class="n">image_idx</span><span class="p">,</span>
+                        <span class="n">num_gt</span><span class="p">,</span>
+                        <span class="n">total_num_anchors</span><span class="p">,</span>
+                        <span class="n">gt_bboxes_per_image</span><span class="p">,</span>
+                        <span class="n">gt_classes</span><span class="p">,</span>
+                        <span class="n">bboxes_preds_per_image</span><span class="p">,</span>
+                        <span class="n">expanded_strides</span><span class="p">,</span>
+                        <span class="n">x_shifts</span><span class="p">,</span>
+                        <span class="n">y_shifts</span><span class="p">,</span>
+                        <span class="n">cls_preds</span><span class="p">,</span>
+                        <span class="n">obj_preds</span><span class="p">,</span>
+                        <span class="s2">&quot;cpu&quot;</span><span class="p">,</span>
+                    <span class="p">)</span>
 
 
                 <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">empty_cache</span><span class="p">()</span>
                 <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">empty_cache</span><span class="p">()</span>
                 <span class="n">num_fg</span> <span class="o">+=</span> <span class="n">num_fg_img</span>
                 <span class="n">num_fg</span> <span class="o">+=</span> <span class="n">num_fg_img</span>
@@ -310,9 +346,13 @@
                 <span class="n">obj_target</span> <span class="o">=</span> <span class="n">fg_mask</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
                 <span class="n">obj_target</span> <span class="o">=</span> <span class="n">fg_mask</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
                 <span class="n">reg_target</span> <span class="o">=</span> <span class="n">gt_bboxes_per_image</span><span class="p">[</span><span class="n">matched_gt_inds</span><span class="p">]</span>
                 <span class="n">reg_target</span> <span class="o">=</span> <span class="n">gt_bboxes_per_image</span><span class="p">[</span><span class="n">matched_gt_inds</span><span class="p">]</span>
                 <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_l1</span><span class="p">:</span>
                 <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_l1</span><span class="p">:</span>
-                    <span class="n">l1_target</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_l1_target</span><span class="p">(</span><span class="n">transformed_outputs</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">((</span><span class="n">num_fg_img</span><span class="p">,</span> <span class="mi">4</span><span class="p">)),</span>
-                                                   <span class="n">gt_bboxes_per_image</span><span class="p">[</span><span class="n">matched_gt_inds</span><span class="p">],</span> <span class="n">expanded_strides</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="n">fg_mask</span><span class="p">],</span>
-                                                   <span class="n">x_shifts</span><span class="o">=</span><span class="n">x_shifts</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="n">fg_mask</span><span class="p">],</span> <span class="n">y_shifts</span><span class="o">=</span><span class="n">y_shifts</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="n">fg_mask</span><span class="p">])</span>
+                    <span class="n">l1_target</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_l1_target</span><span class="p">(</span>
+                        <span class="n">transformed_outputs</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">((</span><span class="n">num_fg_img</span><span class="p">,</span> <span class="mi">4</span><span class="p">)),</span>
+                        <span class="n">gt_bboxes_per_image</span><span class="p">[</span><span class="n">matched_gt_inds</span><span class="p">],</span>
+                        <span class="n">expanded_strides</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="n">fg_mask</span><span class="p">],</span>
+                        <span class="n">x_shifts</span><span class="o">=</span><span class="n">x_shifts</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="n">fg_mask</span><span class="p">],</span>
+                        <span class="n">y_shifts</span><span class="o">=</span><span class="n">y_shifts</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="n">fg_mask</span><span class="p">],</span>
+                    <span class="p">)</span>
 
 
             <span class="c1"># collect targets for all loss terms over the whole batch</span>
             <span class="c1"># collect targets for all loss terms over the whole batch</span>
             <span class="n">cls_targets</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">cls_target</span><span class="p">)</span>
             <span class="n">cls_targets</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">cls_target</span><span class="p">)</span>
@@ -343,13 +383,21 @@
         <span class="n">reg_weight</span> <span class="o">=</span> <span class="mf">5.0</span>
         <span class="n">reg_weight</span> <span class="o">=</span> <span class="mf">5.0</span>
         <span class="n">loss</span> <span class="o">=</span> <span class="n">reg_weight</span> <span class="o">*</span> <span class="n">loss_iou</span> <span class="o">+</span> <span class="n">loss_obj</span> <span class="o">+</span> <span class="n">loss_cls</span> <span class="o">+</span> <span class="n">loss_l1</span>
         <span class="n">loss</span> <span class="o">=</span> <span class="n">reg_weight</span> <span class="o">*</span> <span class="n">loss_iou</span> <span class="o">+</span> <span class="n">loss_obj</span> <span class="o">+</span> <span class="n">loss_cls</span> <span class="o">+</span> <span class="n">loss_l1</span>
 
 
-        <span class="k">return</span> <span class="n">loss</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">loss_iou</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">loss_obj</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span cl
-                                <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">loss_l1</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">loss</span><span class="o">.</span><span class="n">device</span><span class="p">),</span>
-                                <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">num_fg</span> <span class="o">/</span> <span class="nb">max</span><span class="p">(</span><span class="n">num_gts</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</spa
-                                <span class="n">loss</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)))</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span>
-
-<div class="viewcode-block" id="YoloXDetectionLoss.prepare_predictions"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.YoloXDetectionLoss.prepare_predictions">[docs]</a>    <span class="k">def</span> <span class="nf">prepare_predictions</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">predictions</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><spa
-            <span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class
+        <span class="k">return</span> <span class="p">(</span>
+            <span class="n">loss</span><span class="p">,</span>
+            <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
+                <span class="p">(</span>
+                    <span class="n">loss_iou</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
+                    <span class="n">loss_obj</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
+                    <span class="n">loss_cls</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
+                    <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">loss_l1</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">loss</span><span class="o">.</span><span class="n">device</span><span class="p">),</span>
+                    <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">num_fg</span> <span class="o">/</span> <span class="nb">max</span><span class="p">(</span><span class="n">num_gts</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span clas
+                    <span class="n">loss</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
+                <span class="p">)</span>
+            <span class="p">)</span><span class="o">.</span><span class="n">detach</span><span class="p">(),</span>
+        <span class="p">)</span>
+
+<div class="viewcode-block" id="YoloXDetectionLoss.prepare_predictions"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.YoloXDetectionLoss.prepare_predictions">[docs]</a>    <span class="k">def</span> <span class="nf">prepare_predictions</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">predictions</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">        Convert raw outputs of the network into a format that merges outputs from all levels</span>
 <span class="sd">        Convert raw outputs of the network into a format that merges outputs from all levels</span>
 <span class="sd">        :param predictions:     output from all Yolo levels, each of shape</span>
 <span class="sd">        :param predictions:     output from all Yolo levels, each of shape</span>
@@ -414,7 +462,7 @@
 
 
         <span class="k">return</span> <span class="n">x_shifts</span><span class="p">,</span> <span class="n">y_shifts</span><span class="p">,</span> <span class="n">expanded_strides</span><span class="p">,</span> <span class="n">transformed_outputs</span><span class="p">,</span> <span class="n">raw_outputs</span></div>
         <span class="k">return</span> <span class="n">x_shifts</span><span class="p">,</span> <span class="n">y_shifts</span><span class="p">,</span> <span class="n">expanded_strides</span><span class="p">,</span> <span class="n">transformed_outputs</span><span class="p">,</span> <span class="n">raw_outputs</span></div>
 
 
-<div class="viewcode-block" id="YoloXDetectionLoss.get_l1_target"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.YoloXDetectionLoss.get_l1_target">[docs]</a>    <span class="k">def</span> <span class="nf">get_l1_target</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">l1_target</span><span class="p">,</span> <span class="n">gt</span><span class="p">,</span> <span class="n">stride</s
+<div class="viewcode-block" id="YoloXDetectionLoss.get_l1_target"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.YoloXDetectionLoss.get_l1_target">[docs]</a>    <span class="k">def</span> <span class="nf">get_l1_target</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">l1_target</span><span class="p">,</span> <span class="n">gt</span><span class="p">,</span> <span class="n">stride</span><sp
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">        :param l1_target:   tensor of zeros of shape [Num_cell_gt_pairs x 4]</span>
 <span class="sd">        :param l1_target:   tensor of zeros of shape [Num_cell_gt_pairs x 4]</span>
 <span class="sd">        :param gt:          targets in coordinates [Num_cell_gt_pairs x (4 + 1 + num_classes)]</span>
 <span class="sd">        :param gt:          targets in coordinates [Num_cell_gt_pairs x (4 + 1 + num_classes)]</span>
@@ -427,10 +475,24 @@
         <span class="n">l1_target</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">gt</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">/</span> <span class="n">stride</span> <span class="o">+</span> <span class="n">eps</span><span class="p">)</span>
         <span class="n">l1_target</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">gt</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">/</span> <span class="n">stride</span> <span class="o">+</span> <span class="n">eps</span><span class="p">)</span>
         <span class="k">return</span> <span class="n">l1_target</span></div>
         <span class="k">return</span> <span class="n">l1_target</span></div>
 
 
-<div class="viewcode-block" id="YoloXDetectionLoss.get_assignments"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.YoloXDetectionLoss.get_assignments">[docs]</a>    <span class="nd">@torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">()</span>
-    <span class="k">def</span> <span class="nf">get_assignments</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">image_idx</span><span class="p">,</span> <span class="n">num_gt</span><span class="p">,</span> <span class="n">total_num_anchors</span><span class="p">,</span> <span class="n">gt_bboxes_per_image</span><span class="p">,</span> <span class="n">gt_classes</span><span class="p">,</span>
-                        <span class="n">bboxes_preds_per_image</span><span class="p">,</span> <span class="n">expanded_strides</span><span class="p">,</span> <span class="n">x_shifts</span><span class="p">,</span> <span class="n">y_shifts</span><span class="p">,</span> <span class="n">cls_preds</span><span class="p">,</span>
-                        <span class="n">obj_preds</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s2">&quot;gpu&quot;</span><span class="p">,</span> <span class="n">ious_loss_cost_coeff</span><span class="o">=</span><span class="mf">3.0</span><span class="p">,</span> <span class="n">outside_boxes_and_center_cost_coeff</span><span class="o">=</span><span class="mf">100000.0</span><span class="p">):</span>
+<div class="viewcode-block" id="YoloXDetectionLoss.get_assignments"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.YoloXDetectionLoss.get_assignments">[docs]</a>    <span class="nd">@torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">()</span>
+    <span class="k">def</span> <span class="nf">get_assignments</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">image_idx</span><span class="p">,</span>
+        <span class="n">num_gt</span><span class="p">,</span>
+        <span class="n">total_num_anchors</span><span class="p">,</span>
+        <span class="n">gt_bboxes_per_image</span><span class="p">,</span>
+        <span class="n">gt_classes</span><span class="p">,</span>
+        <span class="n">bboxes_preds_per_image</span><span class="p">,</span>
+        <span class="n">expanded_strides</span><span class="p">,</span>
+        <span class="n">x_shifts</span><span class="p">,</span>
+        <span class="n">y_shifts</span><span class="p">,</span>
+        <span class="n">cls_preds</span><span class="p">,</span>
+        <span class="n">obj_preds</span><span class="p">,</span>
+        <span class="n">mode</span><span class="o">=</span><span class="s2">&quot;gpu&quot;</span><span class="p">,</span>
+        <span class="n">ious_loss_cost_coeff</span><span class="o">=</span><span class="mf">3.0</span><span class="p">,</span>
+        <span class="n">outside_boxes_and_center_cost_coeff</span><span class="o">=</span><span class="mf">100000.0</span><span class="p">,</span>
+    <span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">        Match cells to ground truth:</span>
 <span class="sd">        Match cells to ground truth:</span>
 <span class="sd">            * at most 1 GT per cell</span>
 <span class="sd">            * at most 1 GT per cell</span>
@@ -464,8 +526,7 @@
             <span class="n">y_shifts</span> <span class="o">=</span> <span class="n">y_shifts</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span>
             <span class="n">y_shifts</span> <span class="o">=</span> <span class="n">y_shifts</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span>
 
 
         <span class="c1"># create a mask for foreground cells</span>
         <span class="c1"># create a mask for foreground cells</span>
-        <span class="n">fg_mask</span><span class="p">,</span> <span class="n">is_in_boxes_and_center</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_in_boxes_info</span><span class="p">(</span><span class="n">gt_bboxes_per_image</span><span class="p">,</span> <span class="n">expanded_strides</span><span class="p">,</span>
-                                                                 <span class="n">x_shifts</span><span class="p">,</span> <span class="n">y_shifts</span><span class="p">,</span> <span class="n">total_num_anchors</span><span class="p">,</span> <span class="n">num_gt</span><span class="p">)</span>
+        <span class="n">fg_mask</span><span class="p">,</span> <span class="n">is_in_boxes_and_center</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_in_boxes_info</span><span class="p">(</span><span class="n">gt_bboxes_per_image</span><span class="p">,</span> <span class="n">expanded_strides</span><span class="p">,</span> <span class="n">x_shifts</span><span class="p">,</span> <span class="n">y_shifts</span><span class="p">,</span> <span 
 
 
         <span class="n">bboxes_preds_per_image</span> <span class="o">=</span> <span class="n">bboxes_preds_per_image</span><span class="p">[</span><span class="n">fg_mask</span><span class="p">]</span>
         <span class="n">bboxes_preds_per_image</span> <span class="o">=</span> <span class="n">bboxes_preds_per_image</span><span class="p">[</span><span class="n">fg_mask</span><span class="p">]</span>
         <span class="n">cls_preds_</span> <span class="o">=</span> <span class="n">cls_preds</span><span class="p">[</span><span class="n">image_idx</span><span class="p">][</span><span class="n">fg_mask</span><span class="p">]</span>
         <span class="n">cls_preds_</span> <span class="o">=</span> <span class="n">cls_preds</span><span class="p">[</span><span class="n">image_idx</span><span class="p">][</span><span class="n">fg_mask</span><span class="p">]</span>
@@ -490,12 +551,10 @@
             <span class="n">pair_wise_cls_loss</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">binary_cross_entropy</span><span class="p">(</span><span class="n">cls_preds_</span><span class="o">.</span><span class="n">sqrt_</span><span class="p">(),</span> <span class="n">gt_cls_per_image</span><span class="p">,</span> <span class="n">reduction</span><span class="o">=</span><span class="s2">&quot;none&quot;</span><span class="p">)</span><span cla
             <span class="n">pair_wise_cls_loss</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">binary_cross_entropy</span><span class="p">(</span><span class="n">cls_preds_</span><span class="o">.</span><span class="n">sqrt_</span><span class="p">(),</span> <span class="n">gt_cls_per_image</span><span class="p">,</span> <span class="n">reduction</span><span class="o">=</span><span class="s2">&quot;none&quot;</span><span class="p">)</span><span cla
         <span class="k">del</span> <span class="n">cls_preds_</span>
         <span class="k">del</span> <span class="n">cls_preds_</span>
 
 
-        <span class="n">cost</span> <span class="o">=</span> <span class="n">pair_wise_cls_loss</span> <span class="o">+</span> <span class="n">ious_loss_cost_coeff</span> <span class="o">*</span> <span class="n">pair_wise_ious_loss</span> <span class="o">+</span> <span class="n">outside_boxes_and_center_cost_coeff</span> <span class="o">*</span> <span class="p">(</span>
-            <span class="o">~</span><span class="n">is_in_boxes_and_center</span><span class="p">)</span>
+        <span class="n">cost</span> <span class="o">=</span> <span class="n">pair_wise_cls_loss</span> <span class="o">+</span> <span class="n">ious_loss_cost_coeff</span> <span class="o">*</span> <span class="n">pair_wise_ious_loss</span> <span class="o">+</span> <span class="n">outside_boxes_and_center_cost_coeff</span> <span class="o">*</span> <span class="p">(</span><span class="o">~</span><span class="n">is_in_boxes_and_center</span><span class="p">)</span>
 
 
         <span class="c1"># further filter foregrounds: create pairs between cells and ground truth, based on cost and IoUs</span>
         <span class="c1"># further filter foregrounds: create pairs between cells and ground truth, based on cost and IoUs</span>
-        <span class="n">num_fg</span><span class="p">,</span> <span class="n">gt_matched_classes</span><span class="p">,</span> <span class="n">pred_ious_this_matching</span><span class="p">,</span> <span class="n">matched_gt_inds</span> <span class="o">=</span> \
-            <span class="bp">self</span><span class="o">.</span><span class="n">dynamic_k_matching</span><span class="p">(</span><span class="n">cost</span><span class="p">,</span> <span class="n">pair_wise_ious</span><span class="p">,</span> <span class="n">gt_classes</span><span class="p">,</span> <span class="n">num_gt</span><span class="p">,</span> <span class="n">fg_mask</span><span class="p">)</span>
+        <span class="n">num_fg</span><span class="p">,</span> <span class="n">gt_matched_classes</span><span class="p">,</span> <span class="n">pred_ious_this_matching</span><span class="p">,</span> <span class="n">matched_gt_inds</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dynamic_k_matching</span><span class="p">(</span><span class="n">cost</span><span class="p">,</span> <span class="n">pair_wise_ious</span><span class="p">,</span> <span
         <span class="c1"># discard tensors related to cost</span>
         <span class="c1"># discard tensors related to cost</span>
         <span class="k">del</span> <span class="n">pair_wise_cls_loss</span><span class="p">,</span> <span class="n">cost</span><span class="p">,</span> <span class="n">pair_wise_ious</span><span class="p">,</span> <span class="n">pair_wise_ious_loss</span>
         <span class="k">del</span> <span class="n">pair_wise_cls_loss</span><span class="p">,</span> <span class="n">cost</span><span class="p">,</span> <span class="n">pair_wise_ious</span><span class="p">,</span> <span class="n">pair_wise_ious_loss</span>
 
 
@@ -507,7 +566,7 @@
 
 
         <span class="k">return</span> <span class="n">gt_matched_classes</span><span class="p">,</span> <span class="n">fg_mask</span><span class="p">,</span> <span class="n">pred_ious_this_matching</span><span class="p">,</span> <span class="n">matched_gt_inds</span><span class="p">,</span> <span class="n">num_fg</span></div>
         <span class="k">return</span> <span class="n">gt_matched_classes</span><span class="p">,</span> <span class="n">fg_mask</span><span class="p">,</span> <span class="n">pred_ious_this_matching</span><span class="p">,</span> <span class="n">matched_gt_inds</span><span class="p">,</span> <span class="n">num_fg</span></div>
 
 
-<div class="viewcode-block" id="YoloXDetectionLoss.get_in_boxes_info"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.YoloXDetectionLoss.get_in_boxes_info">[docs]</a>    <span class="k">def</span> <span class="nf">get_in_boxes_info</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">gt_bboxes_per_image</span><span class="p">,</span> <span class="n">expanded_strides</span><span class="p
+<div class="viewcode-block" id="YoloXDetectionLoss.get_in_boxes_info"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.YoloXDetectionLoss.get_in_boxes_info">[docs]</a>    <span class="k">def</span> <span class="nf">get_in_boxes_info</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">gt_bboxes_per_image</span><span class="p">,</span> <span class="n">expanded_strides</span><span class="p">,</sp
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">        Create a mask for all cells, mask in only foreground: cells that have a center located:</span>
 <span class="sd">        Create a mask for all cells, mask in only foreground: cells that have a center located:</span>
 <span class="sd">            * withing a GT box;</span>
 <span class="sd">            * withing a GT box;</span>
@@ -561,14 +620,18 @@
         <span class="c1"># FIND CELL CENTERS THAT ARE WITHIN +- self.center_sampling_radius CELLS FROM GROUND TRUTH BOXES CENTERS</span>
         <span class="c1"># FIND CELL CENTERS THAT ARE WITHIN +- self.center_sampling_radius CELLS FROM GROUND TRUTH BOXES CENTERS</span>
 
 
         <span class="c1"># define fake boxes: instead of ground truth boxes step +- self.center_sampling_radius from their centers</span>
         <span class="c1"># define fake boxes: instead of ground truth boxes step +- self.center_sampling_radius from their centers</span>
-        <span class="n">gt_bboxes_per_image_l</span> <span class="o">=</span> <span class="p">((</span><span class="n">gt_bboxes_per_image</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">])</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">total
-                                 <span class="bp">self</span><span class="o">.</span><span class="n">center_sampling_radius</span> <span class="o">*</span> <span class="n">expanded_strides_per_image</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
-        <span class="n">gt_bboxes_per_image_r</span> <span class="o">=</span> <span class="p">((</span><span class="n">gt_bboxes_per_image</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">])</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">total
-                                 <span class="bp">self</span><span class="o">.</span><span class="n">center_sampling_radius</span> <span class="o">*</span> <span class="n">expanded_strides_per_image</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
-        <span class="n">gt_bboxes_per_image_t</span> <span class="o">=</span> <span class="p">((</span><span class="n">gt_bboxes_per_image</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">total
-                                 <span class="bp">self</span><span class="o">.</span><span class="n">center_sampling_radius</span> <span class="o">*</span> <span class="n">expanded_strides_per_image</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
-        <span class="n">gt_bboxes_per_image_b</span> <span class="o">=</span> <span class="p">((</span><span class="n">gt_bboxes_per_image</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">total
-                                 <span class="bp">self</span><span class="o">.</span><span class="n">center_sampling_radius</span> <span class="o">*</span> <span class="n">expanded_strides_per_image</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
+        <span class="n">gt_bboxes_per_image_l</span> <span class="o">=</span> <span class="p">(</span><span class="n">gt_bboxes_per_image</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">])</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span>
+            <span class="mi">1</span><span class="p">,</span> <span class="n">total_num_anchors</span>
+        <span class="p">)</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">center_sampling_radius</span> <span class="o">*</span> <span class="n">expanded_strides_per_image</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
+        <span class="n">gt_bboxes_per_image_r</span> <span class="o">=</span> <span class="p">(</span><span class="n">gt_bboxes_per_image</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">])</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span>
+            <span class="mi">1</span><span class="p">,</span> <span class="n">total_num_anchors</span>
+        <span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">center_sampling_radius</span> <span class="o">*</span> <span class="n">expanded_strides_per_image</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
+        <span class="n">gt_bboxes_per_image_t</span> <span class="o">=</span> <span class="p">(</span><span class="n">gt_bboxes_per_image</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span>
+            <span class="mi">1</span><span class="p">,</span> <span class="n">total_num_anchors</span>
+        <span class="p">)</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">center_sampling_radius</span> <span class="o">*</span> <span class="n">expanded_strides_per_image</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
+        <span class="n">gt_bboxes_per_image_b</span> <span class="o">=</span> <span class="p">(</span><span class="n">gt_bboxes_per_image</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span>
+            <span class="mi">1</span><span class="p">,</span> <span class="n">total_num_anchors</span>
+        <span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">center_sampling_radius</span> <span class="o">*</span> <span class="n">expanded_strides_per_image</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
 
 
         <span class="n">c_l</span> <span class="o">=</span> <span class="n">x_centers_per_image</span> <span class="o">-</span> <span class="n">gt_bboxes_per_image_l</span>
         <span class="n">c_l</span> <span class="o">=</span> <span class="n">x_centers_per_image</span> <span class="o">-</span> <span class="n">gt_bboxes_per_image_l</span>
         <span class="n">c_r</span> <span class="o">=</span> <span class="n">gt_bboxes_per_image_r</span> <span class="o">-</span> <span class="n">x_centers_per_image</span>
         <span class="n">c_r</span> <span class="o">=</span> <span class="n">gt_bboxes_per_image_r</span> <span class="o">-</span> <span class="n">x_centers_per_image</span>
@@ -582,10 +645,10 @@
         <span class="n">is_in_boxes_anchor</span> <span class="o">=</span> <span class="n">is_in_boxes_all</span> <span class="o">|</span> <span class="n">is_in_centers_all</span>
         <span class="n">is_in_boxes_anchor</span> <span class="o">=</span> <span class="n">is_in_boxes_all</span> <span class="o">|</span> <span class="n">is_in_centers_all</span>
 
 
         <span class="c1"># in boxes AND in centers, preserving a shape [num_GTs x num_FGs]</span>
         <span class="c1"># in boxes AND in centers, preserving a shape [num_GTs x num_FGs]</span>
-        <span class="n">is_in_boxes_and_center</span> <span class="o">=</span> <span class="p">(</span><span class="n">is_in_boxes</span><span class="p">[:,</span> <span class="n">is_in_boxes_anchor</span><span class="p">]</span> <span class="o">&amp;</span> <span class="n">is_in_centers</span><span class="p">[:,</span> <span class="n">is_in_boxes_anchor</span><span class="p">])</span>
+        <span class="n">is_in_boxes_and_center</span> <span class="o">=</span> <span class="n">is_in_boxes</span><span class="p">[:,</span> <span class="n">is_in_boxes_anchor</span><span class="p">]</span> <span class="o">&amp;</span> <span class="n">is_in_centers</span><span class="p">[:,</span> <span class="n">is_in_boxes_anchor</span><span class="p">]</span>
         <span class="k">return</span> <span class="n">is_in_boxes_anchor</span><span class="p">,</span> <span class="n">is_in_boxes_and_center</span></div>
         <span class="k">return</span> <span class="n">is_in_boxes_anchor</span><span class="p">,</span> <span class="n">is_in_boxes_and_center</span></div>
 
 
-<div class="viewcode-block" id="YoloXDetectionLoss.dynamic_k_matching"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.YoloXDetectionLoss.dynamic_k_matching">[docs]</a>    <span class="k">def</span> <span class="nf">dynamic_k_matching</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">cost</span><span class="p">,</span> <span class="n">pair_wise_ious</span><span class="p">,</span> <sp
+<div class="viewcode-block" id="YoloXDetectionLoss.dynamic_k_matching"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.YoloXDetectionLoss.dynamic_k_matching">[docs]</a>    <span class="k">def</span> <span class="nf">dynamic_k_matching</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">cost</span><span class="p">,</span> <span class="n">pair_wise_ious</span><span class="p">,</span> <span clas
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">        :param cost:            pairwise cost, [num_FGs x num_GTs]</span>
 <span class="sd">        :param cost:            pairwise cost, [num_FGs x num_GTs]</span>
 <span class="sd">        :param pair_wise_ious:  pairwise IoUs, [num_FGs x num_GTs]</span>
 <span class="sd">        :param pair_wise_ious:  pairwise IoUs, [num_FGs x num_GTs]</span>
@@ -632,6 +695,399 @@
 
 
         <span class="n">pred_ious_this_matching</span> <span class="o">=</span> <span class="p">(</span><span class="n">matching_matrix</span> <span class="o">*</span> <span class="n">pair_wise_ious</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="mi">0</span><span class="p">)[</span><span class="n">fg_mask_inboxes</span><span class="p">]</span>
         <span class="n">pred_ious_this_matching</span> <span class="o">=</span> <span class="p">(</span><span class="n">matching_matrix</span> <span class="o">*</span> <span class="n">pair_wise_ious</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="mi">0</span><span class="p">)[</span><span class="n">fg_mask_inboxes</span><span class="p">]</span>
         <span class="k">return</span> <span class="n">num_fg</span><span class="p">,</span> <span class="n">gt_matched_classes</span><span class="p">,</span> <span class="n">pred_ious_this_matching</span><span class="p">,</span> <span class="n">matched_gt_inds</span></div></div>
         <span class="k">return</span> <span class="n">num_fg</span><span class="p">,</span> <span class="n">gt_matched_classes</span><span class="p">,</span> <span class="n">pred_ious_this_matching</span><span class="p">,</span> <span class="n">matched_gt_inds</span></div></div>
+
+
+<div class="viewcode-block" id="YoloXFastDetectionLoss"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.YoloXFastDetectionLoss">[docs]</a><span class="k">class</span> <span class="nc">YoloXFastDetectionLoss</span><span class="p">(</span><span class="n">YoloXDetectionLoss</span><span class="p">):</span>
+    <span class="sd">&quot;&quot;&quot;</span>
+<span class="sd">    A completely new implementation of YOLOX loss.</span>
+<span class="sd">    This is NOT an equivalent implementation to the regular yolox loss.</span>
+
+<span class="sd">    * Completely avoids using loops compared to the nested loops in the original implementation.</span>
+<span class="sd">        As a result runs much faster (speedup depends on the type of GPUs, their count, the batch size, etc.).</span>
+<span class="sd">    * Tensors format is very different the original implementation.</span>
+<span class="sd">        Tensors contain image ids, ground truth ids and anchor ids as values to support variable length data.</span>
+<span class="sd">    * There are differences in terms of the algorithm itself:</span>
+<span class="sd">    1. When computing a dynamic k for a ground truth,</span>
+<span class="sd">        in the original implementation they consider the sum of top 10 predictions sorted by ious among the initial</span>
+<span class="sd">        foregrounds of any ground truth in the image,</span>
+<span class="sd">        while in our implementation we consider only the initial foreground of that particular ground truth.</span>
+<span class="sd">        To compensate for that difference we introduce the dynamic_ks_bias hyperparamter which makes the dynamic ks larger.</span>
+<span class="sd">    2. When computing the k matched detections for a ground truth,</span>
+<span class="sd">        in the original implementation they consider the initial foregrounds of any ground truth in the image as candidates,</span>
+<span class="sd">        while in our implementation we consider only the initial foreground of that particular ground truth as candidates.</span>
+<span class="sd">        We believe that this difference is minor.</span>
+
+<span class="sd">    :param dynamic_ks_bias: hyperparameter to compensate for the discrepancies between the regular loss and this loss.</span>
+<span class="sd">    :param sync_num_fgs:    sync num of fgs.</span>
+<span class="sd">                            Can be used for DDP training.</span>
+<span class="sd">    :param obj_loss_fix:    devide by total of num anchors instead num of matching fgs.</span>
+<span class="sd">                            Can be used for objectness loss.</span>
+<span class="sd">    &quot;&quot;&quot;</span>
+
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span> <span class="n">strides</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">,</span> <span class="n">use_l1</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">center_sampling_radius</span><span class="o">=</span><span class="mf">2.5</span><span class="p">,</span> <span class="n">iou_type</span><span class="o">=</span><span class="s2">&quot;iou&quot;</span><s
+    <span class="p">):</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">strides</span><span class="o">=</span><span class="n">strides</span><span class="p">,</span> <span class="n">num_classes</span><span class="o">=</span><span class="n">num_classes</span><span class="p">,</span> <span class="n">use_l1</span><span class="o">=</span><span class="n">use_l1</span><span class="p">,</span> <span class="n">center_s
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">dynamic_ks_bias</span> <span class="o">=</span> <span class="n">dynamic_ks_bias</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">sync_num_fgs</span> <span class="o">=</span> <span class="n">sync_num_fgs</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">obj_loss_fix</span> <span class="o">=</span> <span class="n">obj_loss_fix</span>
+
+    <span class="k">def</span> <span class="nf">_compute_loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">predictions</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">targets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><
+        <span class="sd">&quot;&quot;&quot;</span>
+<span class="sd">        L = L_objectness + L_iou + L_classification + 1[no_aug_epoch]*L_l1</span>
+<span class="sd">        where:</span>
+<span class="sd">            * L_iou, L_classification and L_l1 are calculated only between cells and targets that suit them;</span>
+<span class="sd">            * L_objectness is calculated for all cells.</span>
+
+<span class="sd">        L_classification:</span>
+<span class="sd">            for cells that have suitable ground truths in their grid locations add BCEs</span>
+<span class="sd">            to force a prediction of IoU with a GT in a multi-label way</span>
+<span class="sd">            Coef: 1.</span>
+<span class="sd">        L_iou:</span>
+<span class="sd">            for cells that have suitable ground truths in their grid locations</span>
+<span class="sd">            add (1 - IoU^2), IoU between a predicted box and each GT box, force maximum IoU</span>
+<span class="sd">            Coef: 1.</span>
+<span class="sd">        L_l1:</span>
+<span class="sd">            for cells that have suitable ground truths in their grid locations</span>
+<span class="sd">            l1 distance between the logits and GTs in “logits” format (the inverse of “logits to predictions” ops)</span>
+<span class="sd">            Coef: 1[no_aug_epoch]</span>
+<span class="sd">        L_objectness:</span>
+<span class="sd">            for each cell add BCE with a label of 1 if there is GT assigned to the cell</span>
+<span class="sd">            Coef: 5</span>
+
+<span class="sd">        :param predictions:     output from all Yolo levels, each of shape</span>
+<span class="sd">                                [Batch x Num_Anchors x GridSizeY x GridSizeX x (4 + 1 + Num_classes)]</span>
+<span class="sd">        :param targets:         [Num_targets x (4 + 2)], values on dim 1 are: image id in a batch, class, box x y w h</span>
+
+<span class="sd">        :return:                loss, all losses separately in a detached tensor</span>
+<span class="sd">        &quot;&quot;&quot;</span>
+        <span class="n">x_shifts</span><span class="p">,</span> <span class="n">y_shifts</span><span class="p">,</span> <span class="n">expanded_strides</span><span class="p">,</span> <span class="n">transformed_outputs</span><span class="p">,</span> <span class="n">raw_outputs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prepare_predictions</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span>
+
+        <span class="n">bbox_preds</span> <span class="o">=</span> <span class="n">transformed_outputs</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">]</span>  <span class="c1"># [batch, n_anchors_all, 4]</span>
+        <span class="n">obj_preds</span> <span class="o">=</span> <span class="n">transformed_outputs</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">4</span><span class="p">:</span><span class="mi">5</span><span class="p">]</span>  <span class="c1"># [batch, n_anchors_all, 1]</span>
+        <span class="n">cls_preds</span> <span class="o">=</span> <span class="n">transformed_outputs</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">5</span><span class="p">:]</span>  <span class="c1"># [batch, n_anchors_all, n_cls]</span>
+
+        <span class="c1"># assign cells to ground truths, at most one GT per cell</span>
+        <span class="n">matched_fg_ids</span><span class="p">,</span> <span class="n">matched_gt_classes</span><span class="p">,</span> <span class="n">matched_gt_ids</span><span class="p">,</span> <span class="n">matched_img_ids</span><span class="p">,</span> <span class="n">matched_ious</span><span class="p">,</span> <span class="n">flattened_gts</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_compute_matching</span><span class="p">(</span>
+            <span class="n">bbox_preds</span><span class="p">,</span> <span class="n">cls_preds</span><span class="p">,</span> <span class="n">obj_preds</span><span class="p">,</span> <span class="n">expanded_strides</span><span class="p">,</span> <span class="n">x_shifts</span><span class="p">,</span> <span class="n">y_shifts</span><span class="p">,</span> <span class="n">targets</span>
+        <span class="p">)</span>
+
+        <span class="n">num_gts</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">flattened_gts</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="mi">1</span><span class="p">)</span>
+        <span class="n">num_fg</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">matched_gt_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="mi">1</span><span class="p">)</span>
+        <span class="n">total_num_anchors</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">transformed_outputs</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">transformed_outputs</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class=
+
+        <span class="n">cls_targets</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">one_hot</span><span class="p">(</span><span class="n">matched_gt_classes</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_classes</span><span class="p">)</span
+        <span class="n">obj_targets</span> <span class="o">=</span> <span class="n">transformed_outputs</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">((</span><span class="n">transformed_outputs</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">transformed_outputs</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</s
+        <span class="n">obj_targets</span><span class="p">[</span><span class="n">matched_img_ids</span><span class="p">,</span> <span class="n">matched_fg_ids</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
+        <span class="n">reg_targets</span> <span class="o">=</span> <span class="n">flattened_gts</span><span class="p">[</span><span class="n">matched_gt_ids</span><span class="p">][:,</span> <span class="mi">1</span><span class="p">:]</span>
+        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_l1</span><span class="p">:</span>
+            <span class="n">l1_targets</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_l1_target</span><span class="p">(</span>
+                <span class="n">transformed_outputs</span><span class="o">.</span><span class="n">new_zeros</span><span class="p">((</span><span class="n">num_fg</span><span class="p">,</span> <span class="mi">4</span><span class="p">)),</span>
+                <span class="n">flattened_gts</span><span class="p">[</span><span class="n">matched_gt_ids</span><span class="p">][:,</span> <span class="mi">1</span><span class="p">:],</span>
+                <span class="n">expanded_strides</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()[</span><span class="n">matched_fg_ids</span><span class="p">],</span>
+                <span class="n">x_shifts</span><span class="o">=</span><span class="n">x_shifts</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()[</span><span class="n">matched_fg_ids</span><span class="p">],</span>
+                <span class="n">y_shifts</span><span class="o">=</span><span class="n">y_shifts</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()[</span><span class="n">matched_fg_ids</span><span class="p">],</span>
+            <span class="p">)</span>
+        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">sync_num_fgs</span> <span class="ow">and</span> <span class="n">dist</span><span class="o">.</span><span class="n">group</span><span class="o">.</span><span class="n">WORLD</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">num_fg</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">scalar_tensor</span><span class="p">(</span><span class="n">num_fg</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">matched_gt_ids</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
+            <span class="n">dist</span><span class="o">.</span><span class="n">all_reduce</span><span class="p">(</span><span class="n">num_fg</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">_C</span><span class="o">.</span><span class="n">_distributed_c10d</span><span class="o">.</span><span class="n">ReduceOp</span><span class="o">.</span><span class="n">AVG</span><span class="p">)</span>
+
+        <span class="n">loss_iou</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">iou_loss</span><span class="p">(</span><span class="n">bbox_preds</span><span class="p">[</span><span class="n">matched_img_ids</span><span class="p">,</span> <span class="n">matched_fg_ids</span><span class="p">],</span> <span class="n">reg_targets</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <span clas
+        <span class="n">loss_obj</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">bcewithlog_loss</span><span class="p">(</span><span class="n">obj_preds</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="n">obj_targets</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <
+        <span class="n">loss_cls</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">bcewithlog_loss</span><span class="p">(</span><span class="n">cls_preds</span><span class="p">[</span><span class="n">matched_img_ids</span><span class="p">,</span> <span class="n">matched_fg_ids</span><span class="p">],</span> <span class="n">cls_targets</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <spa
+        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_l1</span><span class="p">:</span>
+            <span class="n">loss_l1</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">l1_loss</span><span class="p">(</span><span class="n">raw_outputs</span><span class="p">[</span><span class="n">matched_img_ids</span><span class="p">,</span> <span class="n">matched_fg_ids</span><span class="p">],</span> <span class="n">l1_targets</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <span cl
+        <span class="k">else</span><span class="p">:</span>
+            <span class="n">loss_l1</span> <span class="o">=</span> <span class="mf">0.0</span>
+
+        <span class="n">reg_weight</span> <span class="o">=</span> <span class="mf">5.0</span>
+        <span class="n">loss</span> <span class="o">=</span> <span class="n">reg_weight</span> <span class="o">*</span> <span class="n">loss_iou</span> <span class="o">+</span> <span class="n">loss_obj</span> <span class="o">+</span> <span class="n">loss_cls</span> <span class="o">+</span> <span class="n">loss_l1</span>
+
+        <span class="k">return</span> <span class="p">(</span>
+            <span class="n">loss</span><span class="p">,</span>
+            <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
+                <span class="p">(</span>
+                    <span class="n">loss_iou</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
+                    <span class="n">loss_obj</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
+                    <span class="n">loss_cls</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
+                    <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">loss_l1</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">transformed_outputs</span><span class="o">.</span><span class="n">device</span><span class="p">),</spa
+                    <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">num_fg</span> <span class="o">/</span> <span class="nb">max</span><span class="p">(</span><span class="n">num_gts</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span clas
+                    <span class="n">loss</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>
+                <span class="p">)</span>
+            <span class="p">)</span><span class="o">.</span><span class="n">detach</span><span class="p">(),</span>
+        <span class="p">)</span>
+
+    <span class="k">def</span> <span class="nf">_get_initial_matching</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span> <span class="n">gt_bboxes</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">expanded_strides</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">x_shifts</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</sp
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]:</span>
+        <span class="sd">&quot;&quot;&quot;</span>
+<span class="sd">        Get candidates using a mask for all cells.</span>
+<span class="sd">        Mask in only foreground cells that have a center located:</span>
+<span class="sd">            * withing a GT box (param: is_in_boxes);</span>
+<span class="sd">            OR</span>
+<span class="sd">            * within a fixed radius around a GT box (center sampling) (param: is_in_centers);</span>
+
+<span class="sd">        return:</span>
+<span class="sd">            initial_matching: get a list of candidates pairs of (gt box id, anchor box id) based on cell = is_in_boxes | is_in_centers.</span>
+<span class="sd">                              shape: [num_candidates, 2]</span>
+<span class="sd">            strong candidate mask: get a list whether a candidate is a strong one or not.</span>
+<span class="sd">                                   strong candidate is a cell from is_in_boxes &amp; is_in_centers.</span>
+<span class="sd">                                   shape: [num_candidates].</span>
+<span class="sd">        &quot;&quot;&quot;</span>
+        <span class="n">cell_x_centers</span> <span class="o">=</span> <span class="p">(</span><span class="n">x_shifts</span> <span class="o">+</span> <span class="mf">0.5</span><span class="p">)</span> <span class="o">*</span> <span class="n">expanded_strides</span>
+        <span class="n">cell_y_centers</span> <span class="o">=</span> <span class="p">(</span><span class="n">y_shifts</span> <span class="o">+</span> <span class="mf">0.5</span><span class="p">)</span> <span class="o">*</span> <span class="n">expanded_strides</span>
+
+        <span class="n">gt_bboxes_x_centers</span> <span class="o">=</span> <span class="n">gt_bboxes</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
+        <span class="n">gt_bboxes_y_centers</span> <span class="o">=</span> <span class="n">gt_bboxes</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
+
+        <span class="n">gt_bboxes_half_w</span> <span class="o">=</span> <span class="p">(</span><span class="mf">0.5</span> <span class="o">*</span> <span class="n">gt_bboxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">])</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
+        <span class="n">gt_bboxes_half_h</span> <span class="o">=</span> <span class="p">(</span><span class="mf">0.5</span> <span class="o">*</span> <span class="n">gt_bboxes</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">])</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
+
+        <span class="n">is_in_boxes</span> <span class="o">=</span> <span class="p">(</span>
+            <span class="p">(</span><span class="n">cell_x_centers</span> <span class="o">&gt;</span> <span class="n">gt_bboxes_x_centers</span> <span class="o">-</span> <span class="n">gt_bboxes_half_w</span><span class="p">)</span>
+            <span class="o">&amp;</span> <span class="p">(</span><span class="n">gt_bboxes_x_centers</span> <span class="o">+</span> <span class="n">gt_bboxes_half_w</span> <span class="o">&gt;</span> <span class="n">cell_x_centers</span><span class="p">)</span>
+            <span class="o">&amp;</span> <span class="p">(</span><span class="n">cell_y_centers</span> <span class="o">&gt;</span> <span class="n">gt_bboxes_y_centers</span> <span class="o">-</span> <span class="n">gt_bboxes_half_h</span><span class="p">)</span>
+            <span class="o">&amp;</span> <span class="p">(</span><span class="n">gt_bboxes_y_centers</span> <span class="o">+</span> <span class="n">gt_bboxes_half_h</span> <span class="o">&gt;</span> <span class="n">cell_y_centers</span><span class="p">)</span>
+        <span class="p">)</span>
+
+        <span class="n">radius_shifts</span> <span class="o">=</span> <span class="mf">2.5</span> <span class="o">*</span> <span class="n">expanded_strides</span>
+
+        <span class="n">is_in_centers</span> <span class="o">=</span> <span class="p">(</span>
+            <span class="p">(</span><span class="n">cell_x_centers</span> <span class="o">+</span> <span class="n">radius_shifts</span> <span class="o">&gt;</span> <span class="n">gt_bboxes_x_centers</span><span class="p">)</span>
+            <span class="o">&amp;</span> <span class="p">(</span><span class="n">gt_bboxes_x_centers</span> <span class="o">&gt;</span> <span class="n">cell_x_centers</span> <span class="o">-</span> <span class="n">radius_shifts</span><span class="p">)</span>
+            <span class="o">&amp;</span> <span class="p">(</span><span class="n">cell_y_centers</span> <span class="o">+</span> <span class="n">radius_shifts</span> <span class="o">&gt;</span> <span class="n">gt_bboxes_y_centers</span><span class="p">)</span>
+            <span class="o">&amp;</span> <span class="p">(</span><span class="n">gt_bboxes_y_centers</span> <span class="o">&gt;</span> <span class="n">cell_y_centers</span> <span class="o">-</span> <span class="n">radius_shifts</span><span class="p">)</span>
+        <span class="p">)</span>
+
+        <span class="n">initial_mask</span> <span class="o">=</span> <span class="n">is_in_boxes</span> <span class="o">|</span> <span class="n">is_in_centers</span>
+        <span class="n">initial_matching</span> <span class="o">=</span> <span class="n">initial_mask</span><span class="o">.</span><span class="n">nonzero</span><span class="p">()</span>
+        <span class="n">strong_candidate_mask</span> <span class="o">=</span> <span class="p">(</span><span class="n">is_in_boxes</span> <span class="o">&amp;</span> <span class="n">is_in_centers</span><span class="p">)[</span><span class="n">initial_mask</span><span class="p">]</span>
+
+        <span class="k">return</span> <span class="n">initial_matching</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">initial_matching</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">strong_candidate_mask</span>
+
+    <span class="nd">@torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">()</span>
+    <span class="k">def</span> <span class="nf">_compute_matching</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">bbox_preds</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
+        <span class="n">cls_preds</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
+        <span class="n">obj_preds</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
+        <span class="n">expanded_strides</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
+        <span class="n">x_shifts</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
+        <span class="n">y_shifts</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
+        <span class="n">labels</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
+        <span class="n">ious_loss_cost_coeff</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">3.0</span><span class="p">,</span>
+        <span class="n">outside_boxes_and_center_cost_coeff</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">100000.0</span><span class="p">,</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Ten
+        <span class="sd">&quot;&quot;&quot;</span>
+<span class="sd">        Match cells to ground truth:</span>
+<span class="sd">            * at most 1 GT per cell</span>
+<span class="sd">            * dynamic number of cells per GT</span>
+
+<span class="sd">        :param bbox_preds: predictions of bounding boxes. shape [batch, n_anchors_all, 4]</span>
+<span class="sd">        :param cls_preds:  predictions of class.          shape [batch, n_anchors_all, n_cls]</span>
+<span class="sd">        :param obj_preds:  predictions for objectness.    shape [batch, n_anchors_all, 1]</span>
+<span class="sd">        :param expanded_strides:  stride of the output grid the prediction is coming from. shape [1, n_anchors_all]</span>
+<span class="sd">        :param x_shifts: x coordinate on the grid cell the prediction is coming from.      shape [1, n_anchors_all]</span>
+<span class="sd">        :param y_shifts: y coordinate on the grid cell the prediction is coming from.      shape [1, n_anchors_all]</span>
+<span class="sd">        :param labels:   labels for each grid cell.  shape [n_anchors_all, (4 + 2)]</span>
+<span class="sd">        :return: candidate_fg_ids       shape [num_fg]</span>
+<span class="sd">                 candidate_gt_classes   shape [num_fg]</span>
+<span class="sd">                 candidate_gt_ids       shape [num_fg]</span>
+<span class="sd">                 candidate_img_ids      shape [num_fg]</span>
+<span class="sd">                 candidate_ious         shape [num_fg]</span>
+<span class="sd">                 flattened_gts          shape [num_gts, 5]</span>
+<span class="sd">        &quot;&quot;&quot;</span>
+
+        <span class="n">flattened_gts</span><span class="p">,</span> <span class="n">gt_id_to_img_id</span> <span class="o">=</span> <span class="n">labels</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">:],</span> <span class="n">labels</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">type</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</spa
+
+        <span class="c1"># COMPUTE CANDIDATES</span>
+        <span class="n">candidate_gt_ids</span><span class="p">,</span> <span class="n">candidate_fg_ids</span><span class="p">,</span> <span class="n">strong_candidate_mask</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_initial_matching</span><span class="p">(</span><span class="n">flattened_gts</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">:],</span> <span class="n">expanded_strides</span><span class="p">,</
+        <span class="n">candidate_img_ids</span> <span class="o">=</span> <span class="n">gt_id_to_img_id</span><span class="p">[</span><span class="n">candidate_gt_ids</span><span class="p">]</span>
+        <span class="n">candidate_gts_bbox</span> <span class="o">=</span> <span class="n">flattened_gts</span><span class="p">[</span><span class="n">candidate_gt_ids</span><span class="p">,</span> <span class="mi">1</span><span class="p">:]</span>
+        <span class="n">candidate_det_bbox</span> <span class="o">=</span> <span class="n">bbox_preds</span><span class="p">[</span><span class="n">candidate_img_ids</span><span class="p">,</span> <span class="n">candidate_fg_ids</span><span class="p">]</span>
+
+        <span class="c1"># COMPUTE DYNAMIC KS</span>
+        <span class="n">candidate_ious</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_calculate_pairwise_bbox_iou</span><span class="p">(</span><span class="n">candidate_gts_bbox</span><span class="p">,</span> <span class="n">candidate_det_bbox</span><span class="p">,</span> <span class="n">xyxy</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
+        <span class="n">dynamic_ks</span><span class="p">,</span> <span class="n">matching_index_to_dynamic_k_index</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_compute_dynamic_ks</span><span class="p">(</span><span class="n">candidate_gt_ids</span><span class="p">,</span> <span class="n">candidate_ious</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dynamic_ks_bias</span><span class="p">)
+        <span class="k">del</span> <span class="n">candidate_gts_bbox</span><span class="p">,</span> <span class="n">candidate_det_bbox</span>
+
+        <span class="c1"># ORDER CANDIDATES BY COST</span>
+        <span class="n">candidate_gt_classes</span> <span class="o">=</span> <span class="n">flattened_gts</span><span class="p">[</span><span class="n">candidate_gt_ids</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span>
+        <span class="n">cost_order</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_compute_cost_order</span><span class="p">(</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">num_classes</span><span class="p">,</span>
+            <span class="n">candidate_img_ids</span><span class="p">,</span>
+            <span class="n">candidate_gt_classes</span><span class="p">,</span>
+            <span class="n">candidate_fg_ids</span><span class="p">,</span>
+            <span class="n">candidate_ious</span><span class="p">,</span>
+            <span class="n">cls_preds</span><span class="p">,</span>
+            <span class="n">obj_preds</span><span class="p">,</span>
+            <span class="n">strong_candidate_mask</span><span class="p">,</span>
+            <span class="n">ious_loss_cost_coeff</span><span class="p">,</span>
+            <span class="n">outside_boxes_and_center_cost_coeff</span><span class="p">,</span>
+        <span class="p">)</span>
+
+        <span class="n">candidate_gt_ids</span> <span class="o">=</span> <span class="n">candidate_gt_ids</span><span class="p">[</span><span class="n">cost_order</span><span class="p">]</span>
+        <span class="n">candidate_gt_classes</span> <span class="o">=</span> <span class="n">candidate_gt_classes</span><span class="p">[</span><span class="n">cost_order</span><span class="p">]</span>
+        <span class="n">candidate_img_ids</span> <span class="o">=</span> <span class="n">candidate_img_ids</span><span class="p">[</span><span class="n">cost_order</span><span class="p">]</span>
+        <span class="n">candidate_fg_ids</span> <span class="o">=</span> <span class="n">candidate_fg_ids</span><span class="p">[</span><span class="n">cost_order</span><span class="p">]</span>
+        <span class="n">candidate_ious</span> <span class="o">=</span> <span class="n">candidate_ious</span><span class="p">[</span><span class="n">cost_order</span><span class="p">]</span>
+        <span class="n">matching_index_to_dynamic_k_index</span> <span class="o">=</span> <span class="n">matching_index_to_dynamic_k_index</span><span class="p">[</span><span class="n">cost_order</span><span class="p">]</span>
+        <span class="k">del</span> <span class="n">cost_order</span>
+
+        <span class="c1"># FILTER MATCHING TO LOWEST K COST MATCHES PER GT</span>
+        <span class="n">ranks</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_compute_ranks</span><span class="p">(</span><span class="n">candidate_gt_ids</span><span class="p">)</span>
+        <span class="n">corresponding_dynamic_ks</span> <span class="o">=</span> <span class="n">dynamic_ks</span><span class="p">[</span><span class="n">matching_index_to_dynamic_k_index</span><span class="p">]</span>
+        <span class="n">topk_mask</span> <span class="o">=</span> <span class="n">ranks</span> <span class="o">&lt;</span> <span class="n">corresponding_dynamic_ks</span>
+
+        <span class="n">candidate_gt_ids</span> <span class="o">=</span> <span class="n">candidate_gt_ids</span><span class="p">[</span><span class="n">topk_mask</span><span class="p">]</span>
+        <span class="n">candidate_gt_classes</span> <span class="o">=</span> <span class="n">candidate_gt_classes</span><span class="p">[</span><span class="n">topk_mask</span><span class="p">]</span>
+        <span class="n">candidate_img_ids</span> <span class="o">=</span> <span class="n">candidate_img_ids</span><span class="p">[</span><span class="n">topk_mask</span><span class="p">]</span>
+        <span class="n">candidate_fg_ids</span> <span class="o">=</span> <span class="n">candidate_fg_ids</span><span class="p">[</span><span class="n">topk_mask</span><span class="p">]</span>
+        <span class="n">candidate_ious</span> <span class="o">=</span> <span class="n">candidate_ious</span><span class="p">[</span><span class="n">topk_mask</span><span class="p">]</span>
+        <span class="k">del</span> <span class="n">ranks</span><span class="p">,</span> <span class="n">topk_mask</span><span class="p">,</span> <span class="n">dynamic_ks</span><span class="p">,</span> <span class="n">matching_index_to_dynamic_k_index</span><span class="p">,</span> <span class="n">corresponding_dynamic_ks</span>
+
+        <span class="c1"># FILTER MATCHING TO AT MOST 1 MATCH FOR DET BY TAKING THE LOWEST COST MATCH</span>
+        <span class="n">candidate_img_and_fg_ids_combined</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_combine_candidates_img_id_fg_id</span><span class="p">(</span><span class="n">candidate_img_ids</span><span class="p">,</span> <span class="n">candidate_fg_ids</span><span class="p">)</span>
+        <span class="n">top1_mask</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_compute_is_first_mask</span><span class="p">(</span><span class="n">candidate_img_and_fg_ids_combined</span><span class="p">)</span>
+        <span class="n">candidate_gt_ids</span> <span class="o">=</span> <span class="n">candidate_gt_ids</span><span class="p">[</span><span class="n">top1_mask</span><span class="p">]</span>
+        <span class="n">candidate_gt_classes</span> <span class="o">=</span> <span class="n">candidate_gt_classes</span><span class="p">[</span><span class="n">top1_mask</span><span class="p">]</span>
+        <span class="n">candidate_fg_ids</span> <span class="o">=</span> <span class="n">candidate_fg_ids</span><span class="p">[</span><span class="n">top1_mask</span><span class="p">]</span>
+        <span class="n">candidate_img_ids</span> <span class="o">=</span> <span class="n">candidate_img_ids</span><span class="p">[</span><span class="n">top1_mask</span><span class="p">]</span>
+        <span class="n">candidate_ious</span> <span class="o">=</span> <span class="n">candidate_ious</span><span class="p">[</span><span class="n">top1_mask</span><span class="p">]</span>
+
+        <span class="k">return</span> <span class="n">candidate_fg_ids</span><span class="p">,</span> <span class="n">candidate_gt_classes</span><span class="p">,</span> <span class="n">candidate_gt_ids</span><span class="p">,</span> <span class="n">candidate_img_ids</span><span class="p">,</span> <span class="n">candidate_ious</span><span class="p">,</span> <span class="n">flattened_gts</span>
+
+    <span class="k">def</span> <span class="nf">_combine_candidates_img_id_fg_id</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">candidate_img_ids</span><span class="p">,</span> <span class="n">candidate_anchor_ids</span><span class="p">):</span>
+        <span class="sd">&quot;&quot;&quot;</span>
+<span class="sd">        Create one dim tensor with unique pairs of img_id and fg_id.</span>
+<span class="sd">        e.g: candidate_img_ids = [0,1,0,0]</span>
+<span class="sd">             candidate_fg_ids = [0,0,0,1]</span>
+<span class="sd">             result = [0,1,0,2]</span>
+<span class="sd">        &quot;&quot;&quot;</span>
+        <span class="n">candidate_img_and_fg_ids_combined</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">((</span><span class="n">candidate_img_ids</span><span class="p">,</span> <span class="n">candidate_anchor_ids</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">unique</span><span class="p">(<
+        <span class="k">return</span> <span class="n">candidate_img_and_fg_ids_combined</span>
+
+    <span class="k">def</span> <span class="nf">_compute_dynamic_ks</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">ious</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">dynamic_ks_bias</
+        <span class="sd">&quot;&quot;&quot;</span>
+<span class="sd">        :param ids:                 ids of GTs, shape: [num_candidates]</span>
+<span class="sd">        :param ious:                pairwise IoUs, shape: [num_candidates]</span>
+<span class="sd">        :param dynamic_ks_bias:     multiply the resulted k to compensate the regular loss</span>
+<span class="sd">        &quot;&quot;&quot;</span>
+        <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">ids</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">&quot;ids must be of shape [num_candidates]&quot;</span>
+        <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">ious</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">&quot;ious must be of shape [num_candidates]&quot;</span>
+        <span class="k">assert</span> <span class="n">ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">ious</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="s2">&quot;num of ids.shape[0] must be the same as num of ious.shape[0]&quot;</span>
+        <span class="c1"># sort ious and ids by ious</span>
+        <span class="n">ious</span><span class="p">,</span> <span class="n">ious_argsort</span> <span class="o">=</span> <span class="n">ious</span><span class="o">.</span><span class="n">sort</span><span class="p">(</span><span class="n">descending</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
+        <span class="n">ids</span> <span class="o">=</span> <span class="n">ids</span><span class="p">[</span><span class="n">ious_argsort</span><span class="p">]</span>
+
+        <span class="c1"># stable sort indices, so that ious are first sorted by id and second by value</span>
+        <span class="n">ids</span><span class="p">,</span> <span class="n">ids_argsort</span> <span class="o">=</span> <span class="n">ids</span><span class="o">.</span><span class="n">sort</span><span class="p">(</span><span class="n">stable</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
+        <span class="n">ious</span> <span class="o">=</span> <span class="n">ious</span><span class="p">[</span><span class="n">ids_argsort</span><span class="p">]</span>
+
+        <span class="n">unique_ids</span><span class="p">,</span> <span class="n">ids_index_to_unique_ids_index</span> <span class="o">=</span> <span class="n">ids</span><span class="o">.</span><span class="n">unique_consecutive</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">return_inverse</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
+        <span class="n">num_unique_ids</span> <span class="o">=</span> <span class="n">unique_ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
+
+        <span class="k">if</span> <span class="n">ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">&gt;</span> <span class="mi">10</span><span class="p">:</span>
+            <span class="n">is_in_top_10</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">10</span><span class="p">,),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span> <span clas
+        <span class="k">else</span><span class="p">:</span>
+            <span class="n">is_in_top_10</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">ids</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">)</span>
+
+        <span class="n">dynamic_ks</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">num_unique_ids</span><span class="p">,),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">ious</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">ious</span><span class="o">.</span><sp
+        <span class="n">dynamic_ks</span><span class="o">.</span><span class="n">index_put_</span><span class="p">((</span><span class="n">ids_index_to_unique_ids_index</span><span class="p">,),</span> <span class="n">is_in_top_10</span> <span class="o">*</span> <span class="n">ious</span><span class="p">,</span> <span class="n">accumulate</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
+        <span class="k">if</span> <span class="n">dynamic_ks_bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">dynamic_ks</span> <span class="o">*=</span> <span class="n">dynamic_ks_bias</span>
+        <span class="n">dynamic_ks</span> <span class="o">=</span> <span class="n">dynamic_ks</span><span class="o">.</span><span class="n">long</span><span class="p">()</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">min</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
+
+        <span class="n">all_argsort</span> <span class="o">=</span> <span class="n">ious_argsort</span><span class="p">[</span><span class="n">ids_argsort</span><span class="p">]</span>
+        <span class="n">inverse_all_argsort</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">ious_argsort</span><span class="p">)</span>
+        <span class="n">inverse_all_argsort</span><span class="p">[</span><span class="n">all_argsort</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">all_argsort</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">all_arg
+
+        <span class="k">return</span> <span class="n">dynamic_ks</span><span class="p">,</span> <span class="n">ids_index_to_unique_ids_index</span><span class="p">[</span><span class="n">inverse_all_argsort</span><span class="p">]</span>
+
+    <span class="k">def</span> <span class="nf">_compute_cost_order</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">num_classes</span><span class="p">,</span>
+        <span class="n">candidate_gt_img_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
+        <span class="n">candidate_gt_classes</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
+        <span class="n">candidate_anchor_ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
+        <span class="n">candidate_ious</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
+        <span class="n">cls_preds</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
+        <span class="n">obj_preds</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
+        <span class="n">strong_candidate_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
+        <span class="n">ious_loss_cost_coeff</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
+        <span class="n">outside_boxes_and_center_cost_coeff</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
+    <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span>
+        <span class="n">gt_cls_per_image</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">one_hot</span><span class="p">(</span><span class="n">candidate_gt_classes</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">int64</span><span class="p">),</span> <span class="n">num_classes</span><span class="p">)</span><span class="o">.</span><span class="n">floa
+        <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">amp</span><span class="o">.</span><span class="n">autocast</span><span class="p">(</span><span class="n">enabled</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
+            <span class="n">cls_preds_</span> <span class="o">=</span> <span class="p">(</span>
+                <span class="n">cls_preds</span><span class="p">[</span><span class="n">candidate_gt_img_ids</span><span class="p">,</span> <span class="n">candidate_anchor_ids</span><span class="p">]</span><span class="o">.</span><span class="n">float</span><span class="p">()</span><span class="o">.</span><span class="n">sigmoid_</span><span class="p">()</span>
+                <span class="o">*</span> <span class="n">obj_preds</span><span class="p">[</span><span class="n">candidate_gt_img_ids</span><span class="p">,</span> <span class="n">candidate_anchor_ids</span><span class="p">]</span><span class="o">.</span><span class="n">float</span><span class="p">()</span><span class="o">.</span><span class="n">sigmoid_</span><span class="p">()</span>
+            <span class="p">)</span>
+            <span class="n">pair_wise_cls_cost</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">binary_cross_entropy</span><span class="p">(</span><span class="n">cls_preds_</span><span class="o">.</span><span class="n">sqrt_</span><span class="p">(),</span> <span class="n">gt_cls_per_image</span><span class="p">,</span> <span class="n">reduction</span><span class="o">=</span><span class="s2">&quot;none&quot;</span><span class="p">)</span><span cla
+
+        <span class="n">ious_cost</span> <span class="o">=</span> <span class="o">-</span><span class="n">torch</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">candidate_ious</span> <span class="o">+</span> <span class="mf">1e-8</span><span class="p">)</span>
+        <span class="n">cost</span> <span class="o">=</span> <span class="n">pair_wise_cls_cost</span> <span class="o">+</span> <span class="n">ious_loss_cost_coeff</span> <span class="o">*</span> <span class="n">ious_cost</span> <span class="o">+</span> <span class="n">outside_boxes_and_center_cost_coeff</span> <span class="o">*</span> <span class="n">strong_candidate_mask</span><span class="o">.</span><span class="n">logical_not</span><span class="p">()</span>
+        <span class="k">return</span> <span class="n">cost</span><span class="o">.</span><span class="n">argsort</span><span class="p">()</span>
+
+    <span class="k">def</span> <span class="nf">_calculate_pairwise_bbox_iou</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">bboxes_a</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">bboxes_b</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n"
+        <span class="k">if</span> <span class="n">bboxes_a</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">!=</span> <span class="mi">4</span> <span class="ow">or</span> <span class="n">bboxes_b</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">!=</span> <span class="mi">4</span><span class="p">:</sp
+            <span class="k">raise</span> <span class="ne">IndexError</span>
+
+        <span class="k">if</span> <span class="n">xyxy</span><span class="p">:</span>
+            <span class="n">tl</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">bboxes_a</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">],</span> <span class="n">bboxes_b</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">])</span>
+            <span class="n">br</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">bboxes_a</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:],</span> <span class="n">bboxes_b</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:])</span>
+            <span class="n">area_a</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">bboxes_a</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:]</span> <span class="o">-</span> <span class="n">bboxes_a</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">],</span> <span class="mi">1</span><span class="p">)</span>
+            <span class="n">area_b</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">bboxes_b</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:]</span> <span class="o">-</span> <span class="n">bboxes_b</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">],</span> <span class="mi">1</span><span class="p">)</span>
+        <span class="k">else</span><span class="p">:</span>
+            <span class="n">tl</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span>
+                <span class="p">(</span><span class="n">bboxes_a</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span> <span class="o">-</span> <span class="n">bboxes_a</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:]</span> <span class="o">/</span> <span class="mi">2</span><span class="p">),</span>
+                <span class="p">(</span><span class="n">bboxes_b</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span> <span class="o">-</span> <span class="n">bboxes_b</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:]</span> <span class="o">/</span> <span class="mi">2</span><span class="p">),</span>
+            <span class="p">)</span>
+            <span class="n">br</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">min</span><span class="p">(</span>
+                <span class="p">(</span><span class="n">bboxes_a</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span> <span class="o">+</span> <span class="n">bboxes_a</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:]</span> <span class="o">/</span> <span class="mi">2</span><span class="p">),</span>
+                <span class="p">(</span><span class="n">bboxes_b</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span> <span class="o">+</span> <span class="n">bboxes_b</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:]</span> <span class="o">/</span> <span class="mi">2</span><span class="p">),</span>
+            <span class="p">)</span>
+
+            <span class="n">area_a</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">bboxes_a</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:],</span> <span class="mi">1</span><span class="p">)</span>
+            <span class="n">area_b</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">bboxes_b</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:],</span> <span class="mi">1</span><span class="p">)</span>
+        <span class="n">en</span> <span class="o">=</span> <span class="p">(</span><span class="n">tl</span> <span class="o">&lt;</span> <span class="n">br</span><span class="p">)</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
+        <span class="n">area_i</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">br</span> <span class="o">-</span> <span class="n">tl</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">en</span>
+        <span class="k">return</span> <span class="n">area_i</span> <span class="o">/</span> <span class="p">(</span><span class="n">area_a</span> <span class="o">+</span> <span class="n">area_b</span> <span class="o">-</span> <span class="n">area_i</span><span class="p">)</span>
+
+    <span class="k">def</span> <span class="nf">_compute_ranks</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span>
+        <span class="n">ids</span><span class="p">,</span> <span class="n">ids_argsort</span> <span class="o">=</span> <span class="n">ids</span><span class="o">.</span><span class="n">sort</span><span class="p">(</span><span class="n">stable</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
+
+        <span class="k">if</span> <span class="n">ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
+            <span class="n">is_not_first</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span> <span clas
+        <span class="k">else</span><span class="p">:</span>
+            <span class="n">is_not_first</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">ids</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">)</span>
+
+        <span class="n">subtract</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">ids_argsort</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="
+        <span class="n">subtract</span><span class="p">[</span><span class="n">is_not_first</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span>
+        <span class="n">subtract</span> <span class="o">=</span> <span class="n">subtract</span><span class="o">.</span><span class="n">cummax</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
+        <span class="n">rank</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">ids_argsort</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">d
+
+        <span class="n">inverse_argsort</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">ids_argsort</span><span class="p">)</span>
+        <span class="n">inverse_argsort</span><span class="p">[</span><span class="n">ids_argsort</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">ids_argsort</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">ids_argsort
+
+        <span class="k">return</span> <span class="n">rank</span><span class="p">[</span><span class="n">inverse_argsort</span><span class="p">]</span>
+
+    <span class="k">def</span> <span class="nf">_compute_is_first_mask</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">ids</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span>
+        <span class="sd">&quot;&quot;&quot;</span>
+<span class="sd">        Filter fg that matches two gts.</span>
+<span class="sd">        &quot;&quot;&quot;</span>
+        <span class="n">ids</span><span class="p">,</span> <span class="n">ids_argsort</span> <span class="o">=</span> <span class="n">ids</span><span class="o">.</span><span class="n">sort</span><span class="p">(</span><span class="n">stable</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
+
+        <span class="k">if</span> <span class="n">ids</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
+            <span class="n">is_first</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span> <span class="n"
+        <span class="k">else</span><span class="p">:</span>
+            <span class="n">is_first</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">ids</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">)</span>
+
+        <span class="n">inverse_argsort</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">ids_argsort</span><span class="p">)</span>
+        <span class="n">inverse_argsort</span><span class="p">[</span><span class="n">ids_argsort</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">ids_argsort</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">ids_argsort
+
+        <span class="k">return</span> <span class="n">is_first</span><span class="p">[</span><span class="n">inverse_argsort</span><span class="p">]</span></div>
 </pre></div>
 </pre></div>
 
 
            </div>
            </div>
@@ -661,4 +1117,4 @@
   </script> 
   </script> 
 
 
 </body>
 </body>
-</html>
+</html>
Discard
Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
@@ -3,10 +3,11 @@
 <head>
 <head>
   <meta charset="utf-8" />
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-  <title>super_gradients.training.metrics.classification_metrics &mdash; SuperGradients 1.0 documentation</title>
+  <title>super_gradients.training.metrics.classification_metrics &mdash; SuperGradients 3.0.3 documentation</title>
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
+      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
   <!--[if lt IE 9]>
   <!--[if lt IE 9]>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
   <![endif]-->
   <![endif]-->
@@ -14,7 +15,9 @@
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/underscore.js"></script>
         <script src="../../../../_static/underscore.js"></script>
+        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
         <script src="../../../../_static/doctools.js"></script>
         <script src="../../../../_static/doctools.js"></script>
+        <script src="../../../../_static/sphinx_highlight.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="search" title="Search" href="../../../../search.html" /> 
     <link rel="search" title="Search" href="../../../../search.html" /> 
@@ -35,30 +38,29 @@
   </form>
   </form>
 </div>
 </div>
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
-              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
+              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
 <ul>
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
 </ul>
 </ul>
-<p class="caption"><span class="caption-text">Technical Documentation</span></p>
+<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
 <ul>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
-</ul>
-<p class="caption"><span class="caption-text">User Guide</span></p>
-<ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
 </ul>
 </ul>
 
 
         </div>
         </div>
@@ -92,7 +94,7 @@
 <span class="kn">from</span> <span class="nn">torchmetrics</span> <span class="kn">import</span> <span class="n">Metric</span>
 <span class="kn">from</span> <span class="nn">torchmetrics</span> <span class="kn">import</span> <span class="n">Metric</span>
 
 
 
 
-<div class="viewcode-block" id="accuracy"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.accuracy">[docs]</a><span class="k">def</span> <span class="nf">accuracy</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">,</span> <span class="n">topk</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,)):</span>
+<div class="viewcode-block" id="accuracy"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.accuracy">[docs]</a><span class="k">def</span> <span class="nf">accuracy</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">,</span> <span class="n">topk</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,)):</span>
     <span class="sd">&quot;&quot;&quot;Computes the precision@k for the specified values of k</span>
     <span class="sd">&quot;&quot;&quot;Computes the precision@k for the specified values of k</span>
 <span class="sd">    :param output: Tensor / Numpy / List</span>
 <span class="sd">    :param output: Tensor / Numpy / List</span>
 <span class="sd">        The prediction</span>
 <span class="sd">        The prediction</span>
@@ -122,24 +124,26 @@
     <span class="k">return</span> <span class="n">res</span></div>
     <span class="k">return</span> <span class="n">res</span></div>
 
 
 
 
-<div class="viewcode-block" id="Accuracy"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.Accuracy">[docs]</a><span class="k">class</span> <span class="nc">Accuracy</span><span class="p">(</span><span class="n">torchmetrics</span><span class="o">.</span><span class="n">Accuracy</span><span class="p">):</span>
+<div class="viewcode-block" id="Accuracy"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.Accuracy">[docs]</a><span class="k">class</span> <span class="nc">Accuracy</span><span class="p">(</span><span class="n">torchmetrics</span><span class="o">.</span><span class="n">Accuracy</span><span class="p">):</span>
     <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dist_sync_on_step</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
     <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dist_sync_on_step</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
         <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dist_sync_on_step</span><span class="o">=</span><span class="n">dist_sync_on_step</span><span class="p">)</span>
         <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dist_sync_on_step</span><span class="o">=</span><span class="n">dist_sync_on_step</span><span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">greater_is_better</span> <span class="o">=</span> <span class="kc">True</span>
 
 
-<div class="viewcode-block" id="Accuracy.update"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.Accuracy.update">[docs]</a>    <span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">preds</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class=
+<div class="viewcode-block" id="Accuracy.update"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.Accuracy.update">[docs]</a>    <span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">preds</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">targ
         <span class="k">if</span> <span class="n">target</span><span class="o">.</span><span class="n">shape</span> <span class="o">==</span> <span class="n">preds</span><span class="o">.</span><span class="n">shape</span><span class="p">:</span>
         <span class="k">if</span> <span class="n">target</span><span class="o">.</span><span class="n">shape</span> <span class="o">==</span> <span class="n">preds</span><span class="o">.</span><span class="n">shape</span><span class="p">:</span>
             <span class="n">target</span> <span class="o">=</span> <span class="n">target</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>  <span class="c1"># supports smooth labels</span>
             <span class="n">target</span> <span class="o">=</span> <span class="n">target</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>  <span class="c1"># supports smooth labels</span>
         <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">preds</span><span class="o">=</span><span class="n">preds</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">target</span><span class="o">=</span><span class="n">target</span><span class="p">)</span></div></div>
         <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">preds</span><span class="o">=</span><span class="n">preds</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">target</span><span class="o">=</span><span class="n">target</span><span class="p">)</span></div></div>
 
 
 
 
-<div class="viewcode-block" id="Top5"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.Top5">[docs]</a><span class="k">class</span> <span class="nc">Top5</span><span class="p">(</span><span class="n">Metric</span><span class="p">):</span>
+<div class="viewcode-block" id="Top5"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.Top5">[docs]</a><span class="k">class</span> <span class="nc">Top5</span><span class="p">(</span><span class="n">Metric</span><span class="p">):</span>
     <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dist_sync_on_step</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
     <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dist_sync_on_step</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
         <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dist_sync_on_step</span><span class="o">=</span><span class="n">dist_sync_on_step</span><span class="p">)</span>
         <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dist_sync_on_step</span><span class="o">=</span><span class="n">dist_sync_on_step</span><span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">greater_is_better</span> <span class="o">=</span> <span class="kc">True</span>
 
 
-        <span class="bp">self</span><span class="o">.</span><span class="n">add_state</span><span class="p">(</span><span class="s2">&quot;correct&quot;</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="mf">0.</span><span class="p">),</span> <span class="n">dist_reduce_fx</span><span class="o">=</span><span class="s2">&quot;sum&quot;</span><spa
+        <span class="bp">self</span><span class="o">.</span><span class="n">add_state</span><span class="p">(</span><span class="s2">&quot;correct&quot;</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="mf">0.0</span><span class="p">),</span> <span class="n">dist_reduce_fx</span><span class="o">=</span><span class="s2">&quot;sum&quot;</span><sp
         <span class="bp">self</span><span class="o">.</span><span class="n">add_state</span><span class="p">(</span><span class="s2">&quot;total&quot;</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">dist_reduce_fx</span><span class="o">=</span><span class="s2">&quot;sum&quot;</span><span c
         <span class="bp">self</span><span class="o">.</span><span class="n">add_state</span><span class="p">(</span><span class="s2">&quot;total&quot;</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">dist_reduce_fx</span><span class="o">=</span><span class="s2">&quot;sum&quot;</span><span c
 
 
-<div class="viewcode-block" id="Top5.update"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.Top5.update">[docs]</a>    <span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">preds</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">targ
+<div class="viewcode-block" id="Top5.update"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.Top5.update">[docs]</a>    <span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">preds</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">target</span
         <span class="k">if</span> <span class="n">target</span><span class="o">.</span><span class="n">shape</span> <span class="o">==</span> <span class="n">preds</span><span class="o">.</span><span class="n">shape</span><span class="p">:</span>
         <span class="k">if</span> <span class="n">target</span><span class="o">.</span><span class="n">shape</span> <span class="o">==</span> <span class="n">preds</span><span class="o">.</span><span class="n">shape</span><span class="p">:</span>
             <span class="n">target</span> <span class="o">=</span> <span class="n">target</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>  <span class="c1"># supports smooth labels</span>
             <span class="n">target</span> <span class="o">=</span> <span class="n">target</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>  <span class="c1"># supports smooth labels</span>
 
 
@@ -155,21 +159,22 @@
         <span class="bp">self</span><span class="o">.</span><span class="n">correct</span> <span class="o">+=</span> <span class="n">correct5</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">correct</span> <span class="o">+=</span> <span class="n">correct5</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">total</span> <span class="o">+=</span> <span class="n">batch_size</span></div>
         <span class="bp">self</span><span class="o">.</span><span class="n">total</span> <span class="o">+=</span> <span class="n">batch_size</span></div>
 
 
-<div class="viewcode-block" id="Top5.compute"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.Top5.compute">[docs]</a>    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
+<div class="viewcode-block" id="Top5.compute"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.Top5.compute">[docs]</a>    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
         <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">correct</span><span class="o">.</span><span class="n">float</span><span class="p">()</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">total</span></div></div>
         <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">correct</span><span class="o">.</span><span class="n">float</span><span class="p">()</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">total</span></div></div>
 
 
 
 
-<div class="viewcode-block" id="ToyTestClassificationMetric"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.ToyTestClassificationMetric">[docs]</a><span class="k">class</span> <span class="nc">ToyTestClassificationMetric</span><span class="p">(</span><span class="n">Metric</span><span class="p">):</span>
+<div class="viewcode-block" id="ToyTestClassificationMetric"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.ToyTestClassificationMetric">[docs]</a><span class="k">class</span> <span class="nc">ToyTestClassificationMetric</span><span class="p">(</span><span class="n">Metric</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    Dummy classification Mettric object returning 0 always (for testing).</span>
 <span class="sd">    Dummy classification Mettric object returning 0 always (for testing).</span>
 <span class="sd">    &quot;&quot;&quot;</span>
 <span class="sd">    &quot;&quot;&quot;</span>
+
     <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dist_sync_on_step</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
     <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dist_sync_on_step</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
         <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dist_sync_on_step</span><span class="o">=</span><span class="n">dist_sync_on_step</span><span class="p">)</span>
         <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dist_sync_on_step</span><span class="o">=</span><span class="n">dist_sync_on_step</span><span class="p">)</span>
 
 
-<div class="viewcode-block" id="ToyTestClassificationMetric.update"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.ToyTestClassificationMetric.update">[docs]</a>    <span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">preds</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span
+<div class="viewcode-block" id="ToyTestClassificationMetric.update"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.ToyTestClassificationMetric.update">[docs]</a>    <span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">preds</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span c
         <span class="k">pass</span></div>
         <span class="k">pass</span></div>
 
 
-<div class="viewcode-block" id="ToyTestClassificationMetric.compute"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.ToyTestClassificationMetric.compute">[docs]</a>    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
+<div class="viewcode-block" id="ToyTestClassificationMetric.compute"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.ToyTestClassificationMetric.compute">[docs]</a>    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
         <span class="k">return</span> <span class="mi">0</span></div></div>
         <span class="k">return</span> <span class="mi">0</span></div></div>
 </pre></div>
 </pre></div>
 
 
@@ -200,4 +205,4 @@
   </script> 
   </script> 
 
 
 </body>
 </body>
-</html>
+</html>
Discard
Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
@@ -3,10 +3,11 @@
 <head>
 <head>
   <meta charset="utf-8" />
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-  <title>super_gradients.training.metrics.detection_metrics &mdash; SuperGradients 1.0 documentation</title>
+  <title>super_gradients.training.metrics.detection_metrics &mdash; SuperGradients 3.0.3 documentation</title>
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
+      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
   <!--[if lt IE 9]>
   <!--[if lt IE 9]>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
   <![endif]-->
   <![endif]-->
@@ -14,7 +15,9 @@
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/underscore.js"></script>
         <script src="../../../../_static/underscore.js"></script>
+        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
         <script src="../../../../_static/doctools.js"></script>
         <script src="../../../../_static/doctools.js"></script>
+        <script src="../../../../_static/sphinx_highlight.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="search" title="Search" href="../../../../search.html" /> 
     <link rel="search" title="Search" href="../../../../search.html" /> 
@@ -35,30 +38,29 @@
   </form>
   </form>
 </div>
 </div>
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
-              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
+              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
 <ul>
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
 </ul>
 </ul>
-<p class="caption"><span class="caption-text">Technical Documentation</span></p>
+<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
 <ul>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
-</ul>
-<p class="caption"><span class="caption-text">User Guide</span></p>
-<ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
 </ul>
 </ul>
 
 
         </div>
         </div>
@@ -94,10 +96,11 @@
 <span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">compute_detection_matching</span><span class="p">,</span> <span class="n">compute_detection_metrics</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">compute_detection_matching</span><span class="p">,</span> <span class="n">compute_detection_metrics</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">DetectionPostPredictionCallback</span><span class="p">,</span> <span class="n">IouThreshold</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">DetectionPostPredictionCallback</span><span class="p">,</span> <span class="n">IouThreshold</span>
 <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
 <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
+
 <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
 <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
 
 
 
 
-<div class="viewcode-block" id="DetectionMetrics"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.DetectionMetrics">[docs]</a><span class="k">class</span> <span class="nc">DetectionMetrics</span><span class="p">(</span><span class="n">Metric</span><span class="p">):</span>
+<div class="viewcode-block" id="DetectionMetrics"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.DetectionMetrics">[docs]</a><span class="k">class</span> <span class="nc">DetectionMetrics</span><span class="p">(</span><span class="n">Metric</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    DetectionMetrics</span>
 <span class="sd">    DetectionMetrics</span>
 
 
@@ -121,37 +124,52 @@
 <span class="sd">        accumulate_on_cpu:     Run on CPU regardless of device used in other parts.</span>
 <span class="sd">        accumulate_on_cpu:     Run on CPU regardless of device used in other parts.</span>
 <span class="sd">                            This is to avoid &quot;CUDA out of memory&quot; that might happen on GPU (default False)</span>
 <span class="sd">                            This is to avoid &quot;CUDA out of memory&quot; that might happen on GPU (default False)</span>
 <span class="sd">    &quot;&quot;&quot;</span>
 <span class="sd">    &quot;&quot;&quot;</span>
-    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_cls</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
-                 <span class="n">post_prediction_callback</span><span class="p">:</span> <span class="n">DetectionPostPredictionCallback</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
-                 <span class="n">normalize_targets</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
-                 <span class="n">iou_thres</span><span class="p">:</span> <span class="n">IouThreshold</span> <span class="o">=</span> <span class="n">IouThreshold</span><span class="o">.</span><span class="n">MAP_05_TO_095</span><span class="p">,</span>
-                 <span class="n">recall_thres</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
-                 <span class="n">score_thres</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
-                 <span class="n">top_k_predictions</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">100</span><span class="p">,</span>
-                 <span class="n">dist_sync_on_step</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
-                 <span class="n">accumulate_on_cpu</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span>
+
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">num_cls</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
+        <span class="n">post_prediction_callback</span><span class="p">:</span> <span class="n">DetectionPostPredictionCallback</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">normalize_targets</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
+        <span class="n">iou_thres</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">IouThreshold</span><span class="p">,</span> <span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="n">IouThreshold</span><span class="o">.</span><span class="n">MAP_05_TO_095</span><span class="p">,</span>
+        <span class="n">recall_thres</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">score_thres</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
+        <span class="n">top_k_predictions</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">100</span><span class="p">,</span>
+        <span class="n">dist_sync_on_step</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
+        <span class="n">accumulate_on_cpu</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
+    <span class="p">):</span>
         <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dist_sync_on_step</span><span class="o">=</span><span class="n">dist_sync_on_step</span><span class="p">)</span>
         <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dist_sync_on_step</span><span class="o">=</span><span class="n">dist_sync_on_step</span><span class="p">)</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">num_cls</span> <span class="o">=</span> <span class="n">num_cls</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">num_cls</span> <span class="o">=</span> <span class="n">num_cls</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">iou_thres</span> <span class="o">=</span> <span class="n">iou_thres</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">iou_thres</span> <span class="o">=</span> <span class="n">iou_thres</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">map_str</span> <span class="o">=</span> <span class="s1">&#39;mAP@</span><span class="si">%.1f</span><span class="s1">&#39;</span> <span class="o">%</span> <span class="n">iou_thres</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">iou_thres</span><span class="o">.</span><span class="n">is_range</span><span class="p">()</
-        <span class="bp">self</span><span class="o">.</span><span class="n">component_names</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;Precision&quot;</span><span class="p">,</span> <span class="s2">&quot;Recall&quot;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">map_str</span><span class="p">,</span> <span class="s2">&quot;F1&quot;</span><span class="p">]</span>
+
+        <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">iou_thres</span><span class="p">,</span> <span class="n">IouThreshold</span><span class="p">):</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">iou_thresholds</span> <span class="o">=</span> <span class="n">iou_thres</span><span class="o">.</span><span class="n">to_tensor</span><span class="p">()</span>
+        <span class="k">else</span><span class="p">:</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">iou_thresholds</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="n">iou_thres</span><span class="p">])</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">map_str</span> <span class="o">=</span> <span class="s2">&quot;mAP&quot;</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_range_str</span><span class="p">()</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">greater_component_is_better</span> <span class="o">=</span> <span class="p">{</span>
+            <span class="sa">f</span><span class="s2">&quot;Precision</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_range_str</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span>
+            <span class="sa">f</span><span class="s2">&quot;Recall</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_range_str</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span>
+            <span class="sa">f</span><span class="s2">&quot;mAP</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_range_str</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span>
+            <span class="sa">f</span><span class="s2">&quot;F1</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_range_str</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span>
+        <span class="p">}</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">component_names</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">greater_component_is_better</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">components</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">component_names</span><span class="p">)</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">components</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">component_names</span><span class="p">)</span>
+
         <span class="bp">self</span><span class="o">.</span><span class="n">post_prediction_callback</span> <span class="o">=</span> <span class="n">post_prediction_callback</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">post_prediction_callback</span> <span class="o">=</span> <span class="n">post_prediction_callback</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">is_distributed</span> <span class="o">=</span> <span class="n">super_gradients</span><span class="o">.</span><span class="n">is_distributed</span><span class="p">()</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">is_distributed</span> <span class="o">=</span> <span class="n">super_gradients</span><span class="o">.</span><span class="n">is_distributed</span><span class="p">()</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">denormalize_targets</span> <span class="o">=</span> <span class="ow">not</span> <span class="n">normalize_targets</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">denormalize_targets</span> <span class="o">=</span> <span class="ow">not</span> <span class="n">normalize_targets</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">world_size</span> <span class="o">=</span> <span class="kc">None</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">world_size</span> <span class="o">=</span> <span class="kc">None</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">rank</span> <span class="o">=</span> <span class="kc">None</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">rank</span> <span class="o">=</span> <span class="kc">None</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">add_state</span><span class="p">(</span><span class="s2">&quot;matching_info&quot;</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="p">[],</span> <span class="n">dist_reduce_fx</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">add_state</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;matching_info</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_range_str</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="p">[],</span> <span class="n"
 
 
-        <span class="bp">self</span><span class="o">.</span><span class="n">iou_thresholds</span> <span class="o">=</span> <span class="n">iou_thres</span><span class="o">.</span><span class="n">to_tensor</span><span class="p">()</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">recall_thresholds</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">101</span><span class="p">)</span> <span class="k">if</span> <span class="n">recall_thres</span> <span class="ow">is</span> <span class="kc">None</sp
         <span class="bp">self</span><span class="o">.</span><span class="n">recall_thresholds</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">101</span><span class="p">)</span> <span class="k">if</span> <span class="n">recall_thres</span> <span class="ow">is</span> <span class="kc">None</sp
         <span class="bp">self</span><span class="o">.</span><span class="n">score_threshold</span> <span class="o">=</span> <span class="n">score_thres</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">score_threshold</span> <span class="o">=</span> <span class="n">score_thres</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">top_k_predictions</span> <span class="o">=</span> <span class="n">top_k_predictions</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">top_k_predictions</span> <span class="o">=</span> <span class="n">top_k_predictions</span>
 
 
         <span class="bp">self</span><span class="o">.</span><span class="n">accumulate_on_cpu</span> <span class="o">=</span> <span class="n">accumulate_on_cpu</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">accumulate_on_cpu</span> <span class="o">=</span> <span class="n">accumulate_on_cpu</span>
 
 
-<div class="viewcode-block" id="DetectionMetrics.update"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.DetectionMetrics.update">[docs]</a>    <span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">preds</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</s
-               <span class="n">inputs</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">,</span> <span class="n">crowd_targets</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
+<div class="viewcode-block" id="DetectionMetrics.update"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.DetectionMetrics.update">[docs]</a>    <span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">preds</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><spa
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">        Apply NMS and match all the predictions and targets of a given batch, and update the metric state accordingly.</span>
 <span class="sd">        Apply NMS and match all the predictions and targets of a given batch, and update the metric state accordingly.</span>
 
 
@@ -173,27 +191,38 @@
         <span class="n">preds</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">post_prediction_callback</span><span class="p">(</span><span class="n">preds</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
         <span class="n">preds</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">post_prediction_callback</span><span class="p">(</span><span class="n">preds</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
 
 
         <span class="n">new_matching_info</span> <span class="o">=</span> <span class="n">compute_detection_matching</span><span class="p">(</span>
         <span class="n">new_matching_info</span> <span class="o">=</span> <span class="n">compute_detection_matching</span><span class="p">(</span>
-            <span class="n">preds</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">iou_thresholds</span><span class="p">,</span> <span class="n">crowd_targets</span><span class="o">=</span><span class="n">crowd_targets</span><span class="p">,</span>
-            <span class="n">top_k</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">top_k_predictions</span><span class="p">,</span> <span class="n">denormalize_targets</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">denormalize_targets</span><span class="p">,</span>
-            <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">return_on_cpu</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">accumulate_on_cpu</span><span class="p">)</span>
-
-        <span class="n">accumulated_matching_info</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="s2">&quot;matching_info&quot;</span><span class="p">)</span>
-        <span class="nb">setattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="s2">&quot;matching_info&quot;</span><span class="p">,</span> <span class="n">accumulated_matching_info</span> <span class="o">+</span> <span class="n">new_matching_info</span><span class="p">)</span></div>
-
-<div class="viewcode-block" id="DetectionMetrics.compute"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.DetectionMetrics.compute">[docs]</a>    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Unio
+            <span class="n">preds</span><span class="p">,</span>
+            <span class="n">targets</span><span class="p">,</span>
+            <span class="n">height</span><span class="p">,</span>
+            <span class="n">width</span><span class="p">,</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">iou_thresholds</span><span class="p">,</span>
+            <span class="n">crowd_targets</span><span class="o">=</span><span class="n">crowd_targets</span><span class="p">,</span>
+            <span class="n">top_k</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">top_k_predictions</span><span class="p">,</span>
+            <span class="n">denormalize_targets</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">denormalize_targets</span><span class="p">,</span>
+            <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span>
+            <span class="n">return_on_cpu</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">accumulate_on_cpu</span><span class="p">,</span>
+        <span class="p">)</span>
+
+        <span class="n">accumulated_matching_info</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;matching_info</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_range_str</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
+        <span class="nb">setattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;matching_info</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_range_str</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">accumulated_matching_info</span> <span class="o">+</span> <span class="n">new_ma
+
+<div class="viewcode-block" id="DetectionMetrics.compute"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.DetectionMetrics.compute">[docs]</a>    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Union</span>
         <span class="sd">&quot;&quot;&quot;Compute the metrics for all the accumulated results.</span>
         <span class="sd">&quot;&quot;&quot;Compute the metrics for all the accumulated results.</span>
-<span class="sd">            :return: Metrics of interest</span>
+<span class="sd">        :return: Metrics of interest</span>
 <span class="sd">        &quot;&quot;&quot;</span>
 <span class="sd">        &quot;&quot;&quot;</span>
-        <span class="n">mean_ap</span><span class="p">,</span> <span class="n">mean_precision</span><span class="p">,</span> <span class="n">mean_recall</span><span class="p">,</span> <span class="n">mean_f1</span> <span class="o">=</span> <span class="o">-</span><span class="mf">1.</span><span class="p">,</span> <span class="o">-</span><span class="mf">1.</span><span class="p">,</span> <span class="o">-</span><span class="mf">1.</span><span class="p">,</span> <span class="o">-</span><span clas
-        <span class="n">accumulated_matching_info</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="s2">&quot;matching_info&quot;</span><span class="p">)</span>
+        <span class="n">mean_ap</span><span class="p">,</span> <span class="n">mean_precision</span><span class="p">,</span> <span class="n">mean_recall</span><span class="p">,</span> <span class="n">mean_f1</span> <span class="o">=</span> <span class="o">-</span><span class="mf">1.0</span><span class="p">,</span> <span class="o">-</span><span class="mf">1.0</span><span class="p">,</span> <span class="o">-</span><span class="mf">1.0</span><span class="p">,</span> <span class="o">-</span><span c
+        <span class="n">accumulated_matching_info</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;matching_info</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_range_str</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
 
 
         <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">accumulated_matching_info</span><span class="p">):</span>
         <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">accumulated_matching_info</span><span class="p">):</span>
             <span class="n">matching_info_tensors</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="nb">list</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span clas
             <span class="n">matching_info_tensors</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="nb">list</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span clas
 
 
             <span class="c1"># shape (n_class, nb_iou_thresh)</span>
             <span class="c1"># shape (n_class, nb_iou_thresh)</span>
             <span class="n">ap</span><span class="p">,</span> <span class="n">precision</span><span class="p">,</span> <span class="n">recall</span><span class="p">,</span> <span class="n">f1</span><span class="p">,</span> <span class="n">unique_classes</span> <span class="o">=</span> <span class="n">compute_detection_metrics</span><span class="p">(</span>
             <span class="n">ap</span><span class="p">,</span> <span class="n">precision</span><span class="p">,</span> <span class="n">recall</span><span class="p">,</span> <span class="n">f1</span><span class="p">,</span> <span class="n">unique_classes</span> <span class="o">=</span> <span class="n">compute_detection_metrics</span><span class="p">(</span>
-                <span class="o">*</span><span class="n">matching_info_tensors</span><span class="p">,</span> <span class="n">recall_thresholds</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">recall_thresholds</span><span class="p">,</span> <span class="n">score_threshold</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">score_threshold</span><span class="p">,</span>
-                <span class="n">device</span><span class="o">=</span><span class="s2">&quot;cpu&quot;</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">accumulate_on_cpu</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
+                <span class="o">*</span><span class="n">matching_info_tensors</span><span class="p">,</span>
+                <span class="n">recall_thresholds</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">recall_thresholds</span><span class="p">,</span>
+                <span class="n">score_threshold</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">score_threshold</span><span class="p">,</span>
+                <span class="n">device</span><span class="o">=</span><span class="s2">&quot;cpu&quot;</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">accumulate_on_cpu</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span>
+            <span class="p">)</span>
 
 
             <span class="c1"># Precision, recall and f1 are computed for smallest IoU threshold (usually 0.5), averaged over classes</span>
             <span class="c1"># Precision, recall and f1 are computed for smallest IoU threshold (usually 0.5), averaged over classes</span>
             <span class="n">mean_precision</span><span class="p">,</span> <span class="n">mean_recall</span><span class="p">,</span> <span class="n">mean_f1</span> <span class="o">=</span> <span class="n">precision</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">mean</span><span class="p">(),</span> <span class="n">recall</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</
             <span class="n">mean_precision</span><span class="p">,</span> <span class="n">mean_recall</span><span class="p">,</span> <span class="n">mean_f1</span> <span class="o">=</span> <span class="n">precision</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">mean</span><span class="p">(),</span> <span class="n">recall</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</
@@ -201,7 +230,12 @@
             <span class="c1"># MaP is averaged over IoU thresholds and over classes</span>
             <span class="c1"># MaP is averaged over IoU thresholds and over classes</span>
             <span class="n">mean_ap</span> <span class="o">=</span> <span class="n">ap</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>
             <span class="n">mean_ap</span> <span class="o">=</span> <span class="n">ap</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>
 
 
-        <span class="k">return</span> <span class="p">{</span><span class="s2">&quot;Precision&quot;</span><span class="p">:</span> <span class="n">mean_precision</span><span class="p">,</span> <span class="s2">&quot;Recall&quot;</span><span class="p">:</span> <span class="n">mean_recall</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">map_str</span><span class="p">:</span> <span class="n">mean_ap</span><span class="p">,</span> <span class="s2">
+        <span class="k">return</span> <span class="p">{</span>
+            <span class="sa">f</span><span class="s2">&quot;Precision</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_range_str</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">:</span> <span class="n">mean_precision</span><span class="p">,</span>
+            <span class="sa">f</span><span class="s2">&quot;Recall</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_range_str</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">:</span> <span class="n">mean_recall</span><span class="p">,</span>
+            <span class="sa">f</span><span class="s2">&quot;mAP</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_range_str</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">:</span> <span class="n">mean_ap</span><span class="p">,</span>
+            <span class="sa">f</span><span class="s2">&quot;F1</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_range_str</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">:</span> <span class="n">mean_f1</span><span class="p">,</span>
+        <span class="p">}</span></div>
 
 
     <span class="k">def</span> <span class="nf">_sync_dist</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dist_sync_fn</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">process_group</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
     <span class="k">def</span> <span class="nf">_sync_dist</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dist_sync_fn</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">process_group</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
@@ -223,10 +257,83 @@
             <span class="n">torch</span><span class="o">.</span><span class="n">distributed</span><span class="o">.</span><span class="n">all_gather_object</span><span class="p">(</span><span class="n">gathered_state_dicts</span><span class="p">,</span> <span class="n">local_state_dict</span><span class="p">)</span>
             <span class="n">torch</span><span class="o">.</span><span class="n">distributed</span><span class="o">.</span><span class="n">all_gather_object</span><span class="p">(</span><span class="n">gathered_state_dicts</span><span class="p">,</span> <span class="n">local_state_dict</span><span class="p">)</span>
             <span class="n">matching_info</span> <span class="o">=</span> <span class="p">[]</span>
             <span class="n">matching_info</span> <span class="o">=</span> <span class="p">[]</span>
             <span class="k">for</span> <span class="n">state_dict</span> <span class="ow">in</span> <span class="n">gathered_state_dicts</span><span class="p">:</span>
             <span class="k">for</span> <span class="n">state_dict</span> <span class="ow">in</span> <span class="n">gathered_state_dicts</span><span class="p">:</span>
-                <span class="n">matching_info</span> <span class="o">+=</span> <span class="n">state_dict</span><span class="p">[</span><span class="s2">&quot;matching_info&quot;</span><span class="p">]</span>
+                <span class="n">matching_info</span> <span class="o">+=</span> <span class="n">state_dict</span><span class="p">[</span><span class="sa">f</span><span class="s2">&quot;matching_info</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_range_str</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">]</span>
             <span class="n">matching_info</span> <span class="o">=</span> <span class="n">tensor_container_to_device</span><span class="p">(</span><span class="n">matching_info</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s2">&quot;cpu&quot;</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">accumulate_on_cpu</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span
             <span class="n">matching_info</span> <span class="o">=</span> <span class="n">tensor_container_to_device</span><span class="p">(</span><span class="n">matching_info</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s2">&quot;cpu&quot;</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">accumulate_on_cpu</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span
 
 
-            <span class="nb">setattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="s2">&quot;matching_info&quot;</span><span class="p">,</span> <span class="n">matching_info</span><span class="p">)</span></div>
+            <span class="nb">setattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;matching_info</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_range_str</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">matching_info</span><span class="p">)</span>
+
+    <span class="k">def</span> <span class="nf">_get_range_str</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
+        <span class="k">return</span> <span class="s2">&quot;@</span><span class="si">%.2f</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">iou_thresholds</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">if</span> <span class="ow">not</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">iou_thresh
+
+
+<div class="viewcode-block" id="DetectionMetrics_050"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.DetectionMetrics_050">[docs]</a><span class="k">class</span> <span class="nc">DetectionMetrics_050</span><span class="p">(</span><span class="n">DetectionMetrics</span><span class="p">):</span>
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">num_cls</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
+        <span class="n">post_prediction_callback</span><span class="p">:</span> <span class="n">DetectionPostPredictionCallback</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">normalize_targets</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
+        <span class="n">recall_thres</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">score_thres</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
+        <span class="n">top_k_predictions</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">100</span><span class="p">,</span>
+        <span class="n">dist_sync_on_step</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
+        <span class="n">accumulate_on_cpu</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
+    <span class="p">):</span>
+
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
+            <span class="n">num_cls</span><span class="p">,</span>
+            <span class="n">post_prediction_callback</span><span class="p">,</span>
+            <span class="n">normalize_targets</span><span class="p">,</span>
+            <span class="n">IouThreshold</span><span class="o">.</span><span class="n">MAP_05</span><span class="p">,</span>
+            <span class="n">recall_thres</span><span class="p">,</span>
+            <span class="n">score_thres</span><span class="p">,</span>
+            <span class="n">top_k_predictions</span><span class="p">,</span>
+            <span class="n">dist_sync_on_step</span><span class="p">,</span>
+            <span class="n">accumulate_on_cpu</span><span class="p">,</span>
+        <span class="p">)</span></div>
+
+
+<div class="viewcode-block" id="DetectionMetrics_075"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.DetectionMetrics_075">[docs]</a><span class="k">class</span> <span class="nc">DetectionMetrics_075</span><span class="p">(</span><span class="n">DetectionMetrics</span><span class="p">):</span>
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">num_cls</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
+        <span class="n">post_prediction_callback</span><span class="p">:</span> <span class="n">DetectionPostPredictionCallback</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">normalize_targets</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
+        <span class="n">recall_thres</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">score_thres</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
+        <span class="n">top_k_predictions</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">100</span><span class="p">,</span>
+        <span class="n">dist_sync_on_step</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
+        <span class="n">accumulate_on_cpu</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
+    <span class="p">):</span>
+
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
+            <span class="n">num_cls</span><span class="p">,</span> <span class="n">post_prediction_callback</span><span class="p">,</span> <span class="n">normalize_targets</span><span class="p">,</span> <span class="mf">0.75</span><span class="p">,</span> <span class="n">recall_thres</span><span class="p">,</span> <span class="n">score_thres</span><span class="p">,</span> <span class="n">top_k_predictions</span><span class="p">,</span> <span class="n">dist_sync_on_step</span><span class="p">,<
+        <span class="p">)</span></div>
+
+
+<div class="viewcode-block" id="DetectionMetrics_050_095"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.DetectionMetrics_050_095">[docs]</a><span class="k">class</span> <span class="nc">DetectionMetrics_050_095</span><span class="p">(</span><span class="n">DetectionMetrics</span><span class="p">):</span>
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">num_cls</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
+        <span class="n">post_prediction_callback</span><span class="p">:</span> <span class="n">DetectionPostPredictionCallback</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">normalize_targets</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
+        <span class="n">recall_thres</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">score_thres</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
+        <span class="n">top_k_predictions</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">100</span><span class="p">,</span>
+        <span class="n">dist_sync_on_step</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
+        <span class="n">accumulate_on_cpu</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
+    <span class="p">):</span>
+
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
+            <span class="n">num_cls</span><span class="p">,</span>
+            <span class="n">post_prediction_callback</span><span class="p">,</span>
+            <span class="n">normalize_targets</span><span class="p">,</span>
+            <span class="n">IouThreshold</span><span class="o">.</span><span class="n">MAP_05_TO_095</span><span class="p">,</span>
+            <span class="n">recall_thres</span><span class="p">,</span>
+            <span class="n">score_thres</span><span class="p">,</span>
+            <span class="n">top_k_predictions</span><span class="p">,</span>
+            <span class="n">dist_sync_on_step</span><span class="p">,</span>
+            <span class="n">accumulate_on_cpu</span><span class="p">,</span>
+        <span class="p">)</span></div>
 </pre></div>
 </pre></div>
 
 
            </div>
            </div>
@@ -256,4 +363,4 @@
   </script> 
   </script> 
 
 
 </body>
 </body>
-</html>
+</html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.training.metrics.metric_utils &mdash; SuperGradients 1.0 documentation</title>
  7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
  10. <!--[if lt IE 9]>
  11. <script src="../../../../_static/js/html5shiv.min.js"></script>
  12. <![endif]-->
  13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
  14. <script src="../../../../_static/jquery.js"></script>
  15. <script src="../../../../_static/underscore.js"></script>
  16. <script src="../../../../_static/doctools.js"></script>
  17. <script src="../../../../_static/js/theme.js"></script>
  18. <link rel="index" title="Index" href="../../../../genindex.html" />
  19. <link rel="search" title="Search" href="../../../../search.html" />
  20. </head>
  21. <body class="wy-body-for-nav">
  22. <div class="wy-grid-for-nav">
  23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  24. <div class="wy-side-scroll">
  25. <div class="wy-side-nav-search" >
  26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
  27. </a>
  28. <div role="search">
  29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
  30. <input type="text" name="q" placeholder="Search docs" />
  31. <input type="hidden" name="check_keywords" value="yes" />
  32. <input type="hidden" name="area" value="default" />
  33. </form>
  34. </div>
  35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
  37. <ul>
  38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
  39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
  40. </ul>
  41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
  42. <ul>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
  45. </ul>
  46. <p class="caption"><span class="caption-text">User Guide</span></p>
  47. <ul>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
  57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
  58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
  59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
  60. </ul>
  61. </div>
  62. </div>
  63. </nav>
  64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  66. <a href="../../../../index.html">SuperGradients</a>
  67. </nav>
  68. <div class="wy-nav-content">
  69. <div class="rst-content">
  70. <div role="navigation" aria-label="Page navigation">
  71. <ul class="wy-breadcrumbs">
  72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
  74. <li>super_gradients.training.metrics.metric_utils</li>
  75. <li class="wy-breadcrumbs-aside">
  76. </li>
  77. </ul>
  78. <hr/>
  79. </div>
  80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  81. <div itemprop="articleBody">
  82. <h1>Source code for super_gradients.training.metrics.metric_utils</h1><div class="highlight"><pre>
  83. <span></span><span class="kn">import</span> <span class="nn">torch</span>
  84. <span class="kn">from</span> <span class="nn">torchmetrics</span> <span class="kn">import</span> <span class="n">MetricCollection</span>
  85. <span class="kn">from</span> <span class="nn">super_gradients.training.utils.utils</span> <span class="kn">import</span> <span class="n">AverageMeter</span>
  86. <div class="viewcode-block" id="get_logging_values"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.metric_utils.get_logging_values">[docs]</a><span class="k">def</span> <span class="nf">get_logging_values</span><span class="p">(</span><span class="n">loss_loggings</span><span class="p">:</span> <span class="n">AverageMeter</span><span class="p">,</span> <span class="n">metrics</span><span class="p">:</span> <span class="n">MetricCollection</span><span class="p">,</span> <span class="n">criterion</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  87. <span class="sd">&quot;&quot;&quot;</span>
  88. <span class="sd"> @param loss_loggings: AverageMeter running average for the loss items</span>
  89. <span class="sd"> @param metrics: MetricCollection object for running user specified metrics</span>
  90. <span class="sd"> @param criterion the object loss_loggings average meter is monitoring, when set to None- only the metrics values are</span>
  91. <span class="sd"> computed and returned.</span>
  92. <span class="sd"> @return: tuple of the computed values</span>
  93. <span class="sd"> &quot;&quot;&quot;</span>
  94. <span class="k">if</span> <span class="n">criterion</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  95. <span class="n">loss_loggingg_avg</span> <span class="o">=</span> <span class="n">loss_loggings</span><span class="o">.</span><span class="n">average</span>
  96. <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">loss_loggingg_avg</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">):</span>
  97. <span class="n">loss_loggingg_avg</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">([</span><span class="n">loss_loggingg_avg</span><span class="p">])</span>
  98. <span class="n">logging_vals</span> <span class="o">=</span> <span class="n">loss_loggingg_avg</span> <span class="o">+</span> <span class="n">get_metrics_results_tuple</span><span class="p">(</span><span class="n">metrics</span><span class="p">)</span>
  99. <span class="k">else</span><span class="p">:</span>
  100. <span class="n">logging_vals</span> <span class="o">=</span> <span class="n">get_metrics_results_tuple</span><span class="p">(</span><span class="n">metrics</span><span class="p">)</span>
  101. <span class="k">return</span> <span class="n">logging_vals</span></div>
  102. <div class="viewcode-block" id="get_metrics_titles"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.metric_utils.get_metrics_titles">[docs]</a><span class="k">def</span> <span class="nf">get_metrics_titles</span><span class="p">(</span><span class="n">metrics_collection</span><span class="p">:</span> <span class="n">MetricCollection</span><span class="p">):</span>
  103. <span class="sd">&quot;&quot;&quot;</span>
  104. <span class="sd"> @param metrics_collection: MetricCollection object for running user specified metrics</span>
  105. <span class="sd"> @return: list of all the names of the computed values list(str)</span>
  106. <span class="sd"> &quot;&quot;&quot;</span>
  107. <span class="n">titles</span> <span class="o">=</span> <span class="p">[]</span>
  108. <span class="k">for</span> <span class="n">metric_name</span><span class="p">,</span> <span class="n">metric</span> <span class="ow">in</span> <span class="n">metrics_collection</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
  109. <span class="k">if</span> <span class="n">metric_name</span> <span class="o">==</span> <span class="s2">&quot;additional_items&quot;</span><span class="p">:</span>
  110. <span class="k">continue</span>
  111. <span class="k">elif</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">metric</span><span class="p">,</span> <span class="s2">&quot;component_names&quot;</span><span class="p">):</span>
  112. <span class="n">titles</span> <span class="o">+=</span> <span class="n">metric</span><span class="o">.</span><span class="n">component_names</span>
  113. <span class="k">else</span><span class="p">:</span>
  114. <span class="n">titles</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">metric_name</span><span class="p">)</span>
  115. <span class="k">return</span> <span class="n">titles</span></div>
  116. <div class="viewcode-block" id="get_metrics_results_tuple"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.metric_utils.get_metrics_results_tuple">[docs]</a><span class="k">def</span> <span class="nf">get_metrics_results_tuple</span><span class="p">(</span><span class="n">metrics_collection</span><span class="p">:</span> <span class="n">MetricCollection</span><span class="p">):</span>
  117. <span class="sd">&quot;&quot;&quot;</span>
  118. <span class="sd"> @param metrics_collection: metrics collection of the user specified metrics</span>
  119. <span class="sd"> @type metrics_collection</span>
  120. <span class="sd"> @return: tuple of metrics values</span>
  121. <span class="sd"> &quot;&quot;&quot;</span>
  122. <span class="k">if</span> <span class="n">metrics_collection</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
  123. <span class="n">results_tuple</span> <span class="o">=</span> <span class="p">()</span>
  124. <span class="k">else</span><span class="p">:</span>
  125. <span class="n">results_tuple</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">flatten_metrics_dict</span><span class="p">(</span><span class="n">metrics_collection</span><span class="o">.</span><span class="n">compute</span><span class="p">())</span><span class="o">.</span><span class="n">values</span><span class="p">())</span>
  126. <span class="k">return</span> <span class="n">results_tuple</span></div>
  127. <div class="viewcode-block" id="flatten_metrics_dict"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.metric_utils.flatten_metrics_dict">[docs]</a><span class="k">def</span> <span class="nf">flatten_metrics_dict</span><span class="p">(</span><span class="n">metrics_dict</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
  128. <span class="sd">&quot;&quot;&quot;</span>
  129. <span class="sd"> :param metrics_dict - dictionary of metric values where values can also be dictionaries containing subvalues</span>
  130. <span class="sd"> (in the case of compound metrics)</span>
  131. <span class="sd"> @return: flattened dict of metric values i.e {metric1_name: metric1_value...}</span>
  132. <span class="sd"> &quot;&quot;&quot;</span>
  133. <span class="n">flattened</span> <span class="o">=</span> <span class="p">{}</span>
  134. <span class="k">for</span> <span class="n">metric_name</span><span class="p">,</span> <span class="n">metric_val</span> <span class="ow">in</span> <span class="n">metrics_dict</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
  135. <span class="k">if</span> <span class="n">metric_name</span> <span class="o">==</span> <span class="s2">&quot;additional_items&quot;</span><span class="p">:</span>
  136. <span class="k">continue</span>
  137. <span class="c1"># COLLECT ALL OF THE COMPONENTS IN THE CASE OF COMPOUND METRICS</span>
  138. <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">metric_val</span><span class="p">,</span> <span class="nb">dict</span><span class="p">):</span>
  139. <span class="k">for</span> <span class="n">sub_metric_name</span><span class="p">,</span> <span class="n">sub_metric_val</span> <span class="ow">in</span> <span class="n">metric_val</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
  140. <span class="n">flattened</span><span class="p">[</span><span class="n">sub_metric_name</span><span class="p">]</span> <span class="o">=</span> <span class="n">sub_metric_val</span>
  141. <span class="k">else</span><span class="p">:</span>
  142. <span class="n">flattened</span><span class="p">[</span><span class="n">metric_name</span><span class="p">]</span> <span class="o">=</span> <span class="n">metric_val</span>
  143. <span class="k">return</span> <span class="n">flattened</span></div>
  144. <div class="viewcode-block" id="get_metrics_dict"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.metric_utils.get_metrics_dict">[docs]</a><span class="k">def</span> <span class="nf">get_metrics_dict</span><span class="p">(</span><span class="n">metrics_tuple</span><span class="p">,</span> <span class="n">metrics_collection</span><span class="p">,</span> <span class="n">loss_logging_item_names</span><span class="p">):</span>
  145. <span class="sd">&quot;&quot;&quot;</span>
  146. <span class="sd"> Returns a dictionary with the epoch results as values and their names as keys.</span>
  147. <span class="sd"> @param metrics_tuple: the result tuple</span>
  148. <span class="sd"> @param metrics_collection: MetricsCollection</span>
  149. <span class="sd"> @param loss_logging_item_names: loss component&#39;s names.</span>
  150. <span class="sd"> @return: dict</span>
  151. <span class="sd"> &quot;&quot;&quot;</span>
  152. <span class="n">keys</span> <span class="o">=</span> <span class="n">loss_logging_item_names</span> <span class="o">+</span> <span class="n">get_metrics_titles</span><span class="p">(</span><span class="n">metrics_collection</span><span class="p">)</span>
  153. <span class="n">metrics_dict</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">keys</span><span class="p">,</span> <span class="nb">list</span><span class="p">(</span><span class="n">metrics_tuple</span><span class="p">)))</span>
  154. <span class="k">return</span> <span class="n">metrics_dict</span></div>
  155. <div class="viewcode-block" id="get_train_loop_description_dict"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.metric_utils.get_train_loop_description_dict">[docs]</a><span class="k">def</span> <span class="nf">get_train_loop_description_dict</span><span class="p">(</span><span class="n">metrics_tuple</span><span class="p">,</span> <span class="n">metrics_collection</span><span class="p">,</span> <span class="n">loss_logging_item_names</span><span class="p">,</span> <span class="o">**</span><span class="n">log_items</span><span class="p">):</span>
  156. <span class="sd">&quot;&quot;&quot;</span>
  157. <span class="sd"> Returns a dictionary with the epoch&#39;s logging items as values and their names as keys, with the purpose of</span>
  158. <span class="sd"> passing it as a description to tqdm&#39;s progress bar.</span>
  159. <span class="sd"> @param metrics_tuple: the result tuple</span>
  160. <span class="sd"> @param metrics_collection: MetricsCollection</span>
  161. <span class="sd"> @param loss_logging_item_names: loss component&#39;s names.</span>
  162. <span class="sd"> @param log_items additional logging items to be rendered.</span>
  163. <span class="sd"> @return: dict</span>
  164. <span class="sd"> &quot;&quot;&quot;</span>
  165. <span class="n">log_items</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">get_metrics_dict</span><span class="p">(</span><span class="n">metrics_tuple</span><span class="p">,</span> <span class="n">metrics_collection</span><span class="p">,</span> <span class="n">loss_logging_item_names</span><span class="p">))</span>
  166. <span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">log_items</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
  167. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">value</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
  168. <span class="n">log_items</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">value</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
  169. <span class="k">return</span> <span class="n">log_items</span></div>
  170. </pre></div>
  171. </div>
  172. </div>
  173. <footer>
  174. <hr/>
  175. <div role="contentinfo">
  176. <p>&#169; Copyright 2021, SuperGradients team.</p>
  177. </div>
  178. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  179. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  180. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  181. </footer>
  182. </div>
  183. </div>
  184. </section>
  185. </div>
  186. <script>
  187. jQuery(function () {
  188. SphinxRtdTheme.Navigation.enable(true);
  189. });
  190. </script>
  191. </body>
  192. </html>
Discard
Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
@@ -3,10 +3,11 @@
 <head>
 <head>
   <meta charset="utf-8" />
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-  <title>super_gradients.training.metrics.segmentation_metrics &mdash; SuperGradients 1.0 documentation</title>
+  <title>super_gradients.training.metrics.segmentation_metrics &mdash; SuperGradients 3.0.3 documentation</title>
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
+      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
   <!--[if lt IE 9]>
   <!--[if lt IE 9]>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
   <![endif]-->
   <![endif]-->
@@ -14,7 +15,9 @@
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/underscore.js"></script>
         <script src="../../../../_static/underscore.js"></script>
+        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
         <script src="../../../../_static/doctools.js"></script>
         <script src="../../../../_static/doctools.js"></script>
+        <script src="../../../../_static/sphinx_highlight.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="search" title="Search" href="../../../../search.html" /> 
     <link rel="search" title="Search" href="../../../../search.html" /> 
@@ -35,30 +38,29 @@
   </form>
   </form>
 </div>
 </div>
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
-              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
+              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
 <ul>
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
 </ul>
 </ul>
-<p class="caption"><span class="caption-text">Technical Documentation</span></p>
+<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
 <ul>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
-</ul>
-<p class="caption"><span class="caption-text">User Guide</span></p>
-<ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
 </ul>
 </ul>
 
 
         </div>
         </div>
@@ -95,7 +97,7 @@
 <span class="kn">from</span> <span class="nn">abc</span> <span class="kn">import</span> <span class="n">ABC</span><span class="p">,</span> <span class="n">abstractmethod</span>
 <span class="kn">from</span> <span class="nn">abc</span> <span class="kn">import</span> <span class="n">ABC</span><span class="p">,</span> <span class="n">abstractmethod</span>
 
 
 
 
-<div class="viewcode-block" id="batch_pix_accuracy"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.batch_pix_accuracy">[docs]</a><span class="k">def</span> <span class="nf">batch_pix_accuracy</span><span class="p">(</span><span class="n">predict</span><span class="p">,</span> <span class="n">target</span><span class="p">):</span>
+<span class="k">def</span> <span class="nf">batch_pix_accuracy</span><span class="p">(</span><span class="n">predict</span><span class="p">,</span> <span class="n">target</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;Batch Pixel Accuracy</span>
     <span class="sd">&quot;&quot;&quot;Batch Pixel Accuracy</span>
 <span class="sd">    Args:</span>
 <span class="sd">    Args:</span>
 <span class="sd">        predict: input 4D tensor</span>
 <span class="sd">        predict: input 4D tensor</span>
@@ -106,12 +108,11 @@
     <span class="n">target</span> <span class="o">=</span> <span class="n">target</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="o">+</span> <span class="mi">1</span>
     <span class="n">target</span> <span class="o">=</span> <span class="n">target</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="o">+</span> <span class="mi">1</span>
     <span class="n">pixel_labeled</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">target</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">)</span>
     <span class="n">pixel_labeled</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">target</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">)</span>
     <span class="n">pixel_correct</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">((</span><span class="n">predict</span> <span class="o">==</span> <span class="n">target</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">target</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">))</span>
     <span class="n">pixel_correct</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">((</span><span class="n">predict</span> <span class="o">==</span> <span class="n">target</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">target</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">))</span>
-    <span class="k">assert</span> <span class="n">pixel_correct</span> <span class="o">&lt;=</span> <span class="n">pixel_labeled</span><span class="p">,</span> \
-        <span class="s2">&quot;Correct area should be smaller than Labeled&quot;</span>
-    <span class="k">return</span> <span class="n">pixel_correct</span><span class="p">,</span> <span class="n">pixel_labeled</span></div>
+    <span class="k">assert</span> <span class="n">pixel_correct</span> <span class="o">&lt;=</span> <span class="n">pixel_labeled</span><span class="p">,</span> <span class="s2">&quot;Correct area should be smaller than Labeled&quot;</span>
+    <span class="k">return</span> <span class="n">pixel_correct</span><span class="p">,</span> <span class="n">pixel_labeled</span>
 
 
 
 
-<div class="viewcode-block" id="batch_intersection_union"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.batch_intersection_union">[docs]</a><span class="k">def</span> <span class="nf">batch_intersection_union</span><span class="p">(</span><span class="n">predict</span><span class="p">,</span> <span class="n">target</span><span class="p">,</span> <span class="n">nclass</span><span class="p">):</span>
+<span class="k">def</span> <span class="nf">batch_intersection_union</span><span class="p">(</span><span class="n">predict</span><span class="p">,</span> <span class="n">target</span><span class="p">,</span> <span class="n">nclass</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;Batch Intersection of Union</span>
     <span class="sd">&quot;&quot;&quot;Batch Intersection of Union</span>
 <span class="sd">    Args:</span>
 <span class="sd">    Args:</span>
 <span class="sd">        predict: input 4D tensor</span>
 <span class="sd">        predict: input 4D tensor</span>
@@ -132,13 +133,12 @@
     <span class="n">area_pred</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">histogram</span><span class="p">(</span><span class="n">predict</span><span class="p">,</span> <span class="n">bins</span><span class="o">=</span><span class="n">nbins</span><span class="p">,</span> <span class="nb">range</span><span class="o">=</span><span class="p">(</span><span class="n">mini</span><span class="p">,</s
     <span class="n">area_pred</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">histogram</span><span class="p">(</span><span class="n">predict</span><span class="p">,</span> <span class="n">bins</span><span class="o">=</span><span class="n">nbins</span><span class="p">,</span> <span class="nb">range</span><span class="o">=</span><span class="p">(</span><span class="n">mini</span><span class="p">,</s
     <span class="n">area_lab</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">histogram</span><span class="p">(</span><span class="n">target</span><span class="p">,</span> <span class="n">bins</span><span class="o">=</span><span class="n">nbins</span><span class="p">,</span> <span class="nb">range</span><span class="o">=</span><span class="p">(</span><span class="n">mini</span><span class="p">,</spa
     <span class="n">area_lab</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">histogram</span><span class="p">(</span><span class="n">target</span><span class="p">,</span> <span class="n">bins</span><span class="o">=</span><span class="n">nbins</span><span class="p">,</span> <span class="nb">range</span><span class="o">=</span><span class="p">(</span><span class="n">mini</span><span class="p">,</spa
     <span class="n">area_union</span> <span class="o">=</span> <span class="n">area_pred</span> <span class="o">+</span> <span class="n">area_lab</span> <span class="o">-</span> <span class="n">area_inter</span>
     <span class="n">area_union</span> <span class="o">=</span> <span class="n">area_pred</span> <span class="o">+</span> <span class="n">area_lab</span> <span class="o">-</span> <span class="n">area_inter</span>
-    <span class="k">assert</span> <span class="p">(</span><span class="n">area_inter</span> <span class="o">&lt;=</span> <span class="n">area_union</span><span class="p">)</span><span class="o">.</span><span class="n">all</span><span class="p">(),</span> \
-        <span class="s2">&quot;Intersection area should be smaller than Union area&quot;</span>
-    <span class="k">return</span> <span class="n">area_inter</span><span class="p">,</span> <span class="n">area_union</span></div>
+    <span class="k">assert</span> <span class="p">(</span><span class="n">area_inter</span> <span class="o">&lt;=</span> <span class="n">area_union</span><span class="p">)</span><span class="o">.</span><span class="n">all</span><span class="p">(),</span> <span class="s2">&quot;Intersection area should be smaller than Union area&quot;</span>
+    <span class="k">return</span> <span class="n">area_inter</span><span class="p">,</span> <span class="n">area_union</span>
 
 
 
 
 <span class="c1"># ref https://github.com/CSAILVision/sceneparsing/blob/master/evaluationCode/utils_eval.py</span>
 <span class="c1"># ref https://github.com/CSAILVision/sceneparsing/blob/master/evaluationCode/utils_eval.py</span>
-<div class="viewcode-block" id="pixel_accuracy"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.pixel_accuracy">[docs]</a><span class="k">def</span> <span class="nf">pixel_accuracy</span><span class="p">(</span><span class="n">im_pred</span><span class="p">,</span> <span class="n">im_lab</span><span class="p">):</span>
+<span class="k">def</span> <span class="nf">pixel_accuracy</span><span class="p">(</span><span class="n">im_pred</span><span class="p">,</span> <span class="n">im_lab</span><span class="p">):</span>
     <span class="n">im_pred</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">im_pred</span><span class="p">)</span>
     <span class="n">im_pred</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">im_pred</span><span class="p">)</span>
     <span class="n">im_lab</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">im_lab</span><span class="p">)</span>
     <span class="n">im_lab</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">im_lab</span><span class="p">)</span>
 
 
@@ -147,7 +147,7 @@
     <span class="n">pixel_labeled</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">im_lab</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">)</span>
     <span class="n">pixel_labeled</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">im_lab</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">)</span>
     <span class="n">pixel_correct</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">((</span><span class="n">im_pred</span> <span class="o">==</span> <span class="n">im_lab</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">im_lab</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">))</span>
     <span class="n">pixel_correct</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">((</span><span class="n">im_pred</span> <span class="o">==</span> <span class="n">im_lab</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">im_lab</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">))</span>
     <span class="c1"># pixel_accuracy = 1.0 * pixel_correct / pixel_labeled</span>
     <span class="c1"># pixel_accuracy = 1.0 * pixel_correct / pixel_labeled</span>
-    <span class="k">return</span> <span class="n">pixel_correct</span><span class="p">,</span> <span class="n">pixel_labeled</span></div>
+    <span class="k">return</span> <span class="n">pixel_correct</span><span class="p">,</span> <span class="n">pixel_labeled</span>
 
 
 
 
 <span class="k">def</span> <span class="nf">_dice_from_confmat</span><span class="p">(</span>
 <span class="k">def</span> <span class="nf">_dice_from_confmat</span><span class="p">(</span>
@@ -188,51 +188,48 @@
         <span class="n">scores</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
         <span class="n">scores</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span>
             <span class="p">[</span>
             <span class="p">[</span>
                 <span class="n">scores</span><span class="p">[:</span><span class="n">ignore_index</span><span class="p">],</span>
                 <span class="n">scores</span><span class="p">[:</span><span class="n">ignore_index</span><span class="p">],</span>
-                <span class="n">scores</span><span class="p">[</span><span class="n">ignore_index</span> <span class="o">+</span> <span class="mi">1</span><span class="p">:],</span>
+                <span class="n">scores</span><span class="p">[</span><span class="n">ignore_index</span> <span class="o">+</span> <span class="mi">1</span> <span class="p">:],</span>
             <span class="p">]</span>
             <span class="p">]</span>
         <span class="p">)</span>
         <span class="p">)</span>
 
 
     <span class="k">return</span> <span class="n">reduce</span><span class="p">(</span><span class="n">scores</span><span class="p">,</span> <span class="n">reduction</span><span class="o">=</span><span class="n">reduction</span><span class="p">)</span>
     <span class="k">return</span> <span class="n">reduce</span><span class="p">(</span><span class="n">scores</span><span class="p">,</span> <span class="n">reduction</span><span class="o">=</span><span class="n">reduction</span><span class="p">)</span>
 
 
 
 
-<div class="viewcode-block" id="intersection_and_union"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.intersection_and_union">[docs]</a><span class="k">def</span> <span class="nf">intersection_and_union</span><span class="p">(</span><span class="n">im_pred</span><span class="p">,</span> <span class="n">im_lab</span><span class="p">,</span> <span class="n">num_class</span><span class="p">):</span>
+<span class="k">def</span> <span class="nf">intersection_and_union</span><span class="p">(</span><span class="n">im_pred</span><span class="p">,</span> <span class="n">im_lab</span><span class="p">,</span> <span class="n">num_class</span><span class="p">):</span>
     <span class="n">im_pred</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">im_pred</span><span class="p">)</span>
     <span class="n">im_pred</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">im_pred</span><span class="p">)</span>
     <span class="n">im_lab</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">im_lab</span><span class="p">)</span>
     <span class="n">im_lab</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">im_lab</span><span class="p">)</span>
     <span class="c1"># Remove classes from unlabeled pixels in gt image.</span>
     <span class="c1"># Remove classes from unlabeled pixels in gt image.</span>
     <span class="n">im_pred</span> <span class="o">=</span> <span class="n">im_pred</span> <span class="o">*</span> <span class="p">(</span><span class="n">im_lab</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">)</span>
     <span class="n">im_pred</span> <span class="o">=</span> <span class="n">im_pred</span> <span class="o">*</span> <span class="p">(</span><span class="n">im_lab</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">)</span>
     <span class="c1"># Compute area intersection:</span>
     <span class="c1"># Compute area intersection:</span>
     <span class="n">intersection</span> <span class="o">=</span> <span class="n">im_pred</span> <span class="o">*</span> <span class="p">(</span><span class="n">im_pred</span> <span class="o">==</span> <span class="n">im_lab</span><span class="p">)</span>
     <span class="n">intersection</span> <span class="o">=</span> <span class="n">im_pred</span> <span class="o">*</span> <span class="p">(</span><span class="n">im_pred</span> <span class="o">==</span> <span class="n">im_lab</span><span class="p">)</span>
-    <span class="n">area_inter</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">histogram</span><span class="p">(</span><span class="n">intersection</span><span class="p">,</span> <span class="n">bins</span><span class="o">=</span><span class="n">num_class</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span>
-                                 <span class="nb">range</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">num_class</span> <span class="o">-</span> <span class="mi">1</span><span class="p">))</span>
+    <span class="n">area_inter</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">histogram</span><span class="p">(</span><span class="n">intersection</span><span class="p">,</span> <span class="n">bins</span><span class="o">=</span><span class="n">num_class</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="nb">range</span><span class="o">=</span><span clas
     <span class="c1"># Compute area union:</span>
     <span class="c1"># Compute area union:</span>
-    <span class="n">area_pred</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">histogram</span><span class="p">(</span><span class="n">im_pred</span><span class="p">,</span> <span class="n">bins</span><span class="o">=</span><span class="n">num_class</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span>
-                                <span class="nb">range</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">num_class</span> <span class="o">-</span> <span class="mi">1</span><span class="p">))</span>
-    <span class="n">area_lab</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">histogram</span><span class="p">(</span><span class="n">im_lab</span><span class="p">,</span> <span class="n">bins</span><span class="o">=</span><span class="n">num_class</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span>
-                               <span class="nb">range</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">num_class</span> <span class="o">-</span> <span class="mi">1</span><span class="p">))</span>
+    <span class="n">area_pred</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">histogram</span><span class="p">(</span><span class="n">im_pred</span><span class="p">,</span> <span class="n">bins</span><span class="o">=</span><span class="n">num_class</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="nb">range</span><span class="o">=</span><span class="p">
+    <span class="n">area_lab</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">histogram</span><span class="p">(</span><span class="n">im_lab</span><span class="p">,</span> <span class="n">bins</span><span class="o">=</span><span class="n">num_class</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="nb">range</span><span class="o">=</span><span class="p">(<
     <span class="n">area_union</span> <span class="o">=</span> <span class="n">area_pred</span> <span class="o">+</span> <span class="n">area_lab</span> <span class="o">-</span> <span class="n">area_inter</span>
     <span class="n">area_union</span> <span class="o">=</span> <span class="n">area_pred</span> <span class="o">+</span> <span class="n">area_lab</span> <span class="o">-</span> <span class="n">area_inter</span>
-    <span class="k">return</span> <span class="n">area_inter</span><span class="p">,</span> <span class="n">area_union</span></div>
+    <span class="k">return</span> <span class="n">area_inter</span><span class="p">,</span> <span class="n">area_union</span>
 
 
 
 
-<div class="viewcode-block" id="AbstractMetricsArgsPrepFn"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.AbstractMetricsArgsPrepFn">[docs]</a><span class="k">class</span> <span class="nc">AbstractMetricsArgsPrepFn</span><span class="p">(</span><span class="n">ABC</span><span class="p">):</span>
+<span class="k">class</span> <span class="nc">AbstractMetricsArgsPrepFn</span><span class="p">(</span><span class="n">ABC</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    Abstract preprocess metrics arguments class.</span>
 <span class="sd">    Abstract preprocess metrics arguments class.</span>
 <span class="sd">    &quot;&quot;&quot;</span>
 <span class="sd">    &quot;&quot;&quot;</span>
+
     <span class="nd">@abstractmethod</span>
     <span class="nd">@abstractmethod</span>
     <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">preds</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="
     <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">preds</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">        All base classes must implement this function and return a tuple of torch tensors (predictions, target).</span>
 <span class="sd">        All base classes must implement this function and return a tuple of torch tensors (predictions, target).</span>
 <span class="sd">        &quot;&quot;&quot;</span>
 <span class="sd">        &quot;&quot;&quot;</span>
-        <span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">()</span></div>
+        <span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">()</span>
 
 
 
 
-<div class="viewcode-block" id="PreprocessSegmentationMetricsArgs"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.PreprocessSegmentationMetricsArgs">[docs]</a><span class="k">class</span> <span class="nc">PreprocessSegmentationMetricsArgs</span><span class="p">(</span><span class="n">AbstractMetricsArgsPrepFn</span><span class="p">):</span>
+<div class="viewcode-block" id="PreprocessSegmentationMetricsArgs"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.PreprocessSegmentationMetricsArgs">[docs]</a><span class="k">class</span> <span class="nc">PreprocessSegmentationMetricsArgs</span><span class="p">(</span><span class="n">AbstractMetricsArgsPrepFn</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    Default segmentation inputs preprocess function before updating segmentation metrics, handles multiple inputs and</span>
 <span class="sd">    Default segmentation inputs preprocess function before updating segmentation metrics, handles multiple inputs and</span>
 <span class="sd">    apply normalizations.</span>
 <span class="sd">    apply normalizations.</span>
 <span class="sd">    &quot;&quot;&quot;</span>
 <span class="sd">    &quot;&quot;&quot;</span>
-    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
-                 <span class="n">apply_arg_max</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
-                 <span class="n">apply_sigmoid</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
+
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">apply_arg_max</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> <span class="n">apply_sigmoid</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">        :param apply_arg_max: Whether to apply argmax on predictions tensor.</span>
 <span class="sd">        :param apply_arg_max: Whether to apply argmax on predictions tensor.</span>
 <span class="sd">        :param apply_sigmoid:  Whether to apply sigmoid on predictions tensor.</span>
 <span class="sd">        :param apply_sigmoid:  Whether to apply sigmoid on predictions tensor.</span>
@@ -253,18 +250,16 @@
         <span class="k">return</span> <span class="n">preds</span><span class="p">,</span> <span class="n">target</span></div>
         <span class="k">return</span> <span class="n">preds</span><span class="p">,</span> <span class="n">target</span></div>
 
 
 
 
-<div class="viewcode-block" id="PixelAccuracy"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.PixelAccuracy">[docs]</a><span class="k">class</span> <span class="nc">PixelAccuracy</span><span class="p">(</span><span class="n">Metric</span><span class="p">):</span>
-    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
-                 <span class="n">ignore_label</span><span class="o">=-</span><span class="mi">100</span><span class="p">,</span>
-                 <span class="n">dist_sync_on_step</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
-                 <span class="n">metrics_args_prep_fn</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">AbstractMetricsArgsPrepFn</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
+<div class="viewcode-block" id="PixelAccuracy"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.PixelAccuracy">[docs]</a><span class="k">class</span> <span class="nc">PixelAccuracy</span><span class="p">(</span><span class="n">Metric</span><span class="p">):</span>
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">ignore_label</span><span class="o">=-</span><span class="mi">100</span><span class="p">,</span> <span class="n">dist_sync_on_step</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">metrics_args_prep_fn</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span 
         <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dist_sync_on_step</span><span class="o">=</span><span class="n">dist_sync_on_step</span><span class="p">)</span>
         <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dist_sync_on_step</span><span class="o">=</span><span class="n">dist_sync_on_step</span><span class="p">)</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">ignore_label</span> <span class="o">=</span> <span class="n">ignore_label</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">ignore_label</span> <span class="o">=</span> <span class="n">ignore_label</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">add_state</span><span class="p">(</span><span class="s2">&quot;total_correct&quot;</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="mf">0.</span><span class="p">),</span> <span class="n">dist_reduce_fx</span><span class="o">=</span><span class="s2">&quot;sum&quot;</spa
-        <span class="bp">self</span><span class="o">.</span><span class="n">add_state</span><span class="p">(</span><span class="s2">&quot;total_label&quot;</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="mf">0.</span><span class="p">),</span> <span class="n">dist_reduce_fx</span><span class="o">=</span><span class="s2">&quot;sum&quot;</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">greater_is_better</span> <span class="o">=</span> <span class="kc">True</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">add_state</span><span class="p">(</span><span class="s2">&quot;total_correct&quot;</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="mf">0.0</span><span class="p">),</span> <span class="n">dist_reduce_fx</span><span class="o">=</span><span class="s2">&quot;sum&quot;</sp
+        <span class="bp">self</span><span class="o">.</span><span class="n">add_state</span><span class="p">(</span><span class="s2">&quot;total_label&quot;</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="mf">0.0</span><span class="p">),</span> <span class="n">dist_reduce_fx</span><span class="o">=</span><span class="s2">&quot;sum&quot;</span
         <span class="bp">self</span><span class="o">.</span><span class="n">metrics_args_prep_fn</span> <span class="o">=</span> <span class="n">metrics_args_prep_fn</span> <span class="ow">or</span> <span class="n">PreprocessSegmentationMetricsArgs</span><span class="p">(</span><span class="n">apply_arg_max</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">metrics_args_prep_fn</span> <span class="o">=</span> <span class="n">metrics_args_prep_fn</span> <span class="ow">or</span> <span class="n">PreprocessSegmentationMetricsArgs</span><span class="p">(</span><span class="n">apply_arg_max</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
 
 
-<div class="viewcode-block" id="PixelAccuracy.update"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.PixelAccuracy.update">[docs]</a>    <span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">preds</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <s
+<div class="viewcode-block" id="PixelAccuracy.update"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.PixelAccuracy.update">[docs]</a>    <span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">preds</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span clas
         <span class="n">predict</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">metrics_args_prep_fn</span><span class="p">(</span><span class="n">preds</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
         <span class="n">predict</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">metrics_args_prep_fn</span><span class="p">(</span><span class="n">preds</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
 
 
         <span class="n">labeled_mask</span> <span class="o">=</span> <span class="n">target</span><span class="o">.</span><span class="n">ne</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ignore_label</span><span class="p">)</span>
         <span class="n">labeled_mask</span> <span class="o">=</span> <span class="n">target</span><span class="o">.</span><span class="n">ne</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ignore_label</span><span class="p">)</span>
@@ -273,81 +268,117 @@
         <span class="bp">self</span><span class="o">.</span><span class="n">total_correct</span> <span class="o">+=</span> <span class="n">pixel_correct</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">total_correct</span> <span class="o">+=</span> <span class="n">pixel_correct</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">total_label</span> <span class="o">+=</span> <span class="n">pixel_labeled</span></div>
         <span class="bp">self</span><span class="o">.</span><span class="n">total_label</span> <span class="o">+=</span> <span class="n">pixel_labeled</span></div>
 
 
-<div class="viewcode-block" id="PixelAccuracy.compute"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.PixelAccuracy.compute">[docs]</a>    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
-        <span class="n">_total_correct</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">total_correct</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s1">&#39;int64&#39;<
-        <span class="n">_total_label</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">total_label</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s1">&#39;int64&#39;</spa
+<div class="viewcode-block" id="PixelAccuracy.compute"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.PixelAccuracy.compute">[docs]</a>    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
+        <span class="n">_total_correct</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">total_correct</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">&quot;int64&quot
+        <span class="n">_total_label</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">total_label</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">&quot;int64&quot;</s
         <span class="n">pix_acc</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">float64</span><span class="p">(</span><span class="mf">1.0</span><span class="p">)</span> <span class="o">*</span> <span class="n">_total_correct</span> <span class="o">/</span> <span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">spacing</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n"
         <span class="n">pix_acc</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">float64</span><span class="p">(</span><span class="mf">1.0</span><span class="p">)</span> <span class="o">*</span> <span class="n">_total_correct</span> <span class="o">/</span> <span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">spacing</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n"
         <span class="k">return</span> <span class="n">pix_acc</span></div></div>
         <span class="k">return</span> <span class="n">pix_acc</span></div></div>
 
 
 
 
-<div class="viewcode-block" id="IoU"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.IoU">[docs]</a><span class="k">class</span> <span class="nc">IoU</span><span class="p">(</span><span class="n">torchmetrics</span><span class="o">.</span><span class="n">JaccardIndex</span><span class="p">):</span>
-    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
-                 <span class="n">num_classes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
-                 <span class="n">dist_sync_on_step</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
-                 <span class="n">ignore_index</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
-                 <span class="n">reduction</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;elementwise_mean&quot;</span><span class="p">,</span>
-                 <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span>
-                 <span class="n">metrics_args_prep_fn</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">AbstractMetricsArgsPrepFn</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
-        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">num_classes</span><span class="o">=</span><span class="n">num_classes</span><span class="p">,</span> <span class="n">dist_sync_on_step</span><span class="o">=</span><span class="n">dist_sync_on_step</span><span class="p">,</span> <span class="n">ignore_index</span><span class="o">=</span><span class="n">ignore_index</span><span class="p">,
-                         <span class="n">reduction</span><span class="o">=</span><span class="n">reduction</span><span class="p">,</span> <span class="n">threshold</span><span class="o">=</span><span class="n">threshold</span><span class="p">)</span>
+<div class="viewcode-block" id="IoU"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.IoU">[docs]</a><span class="k">class</span> <span class="nc">IoU</span><span class="p">(</span><span class="n">torchmetrics</span><span class="o">.</span><span class="n">JaccardIndex</span><span class="p">):</span>
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">num_classes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
+        <span class="n">dist_sync_on_step</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
+        <span class="n">ignore_index</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">reduction</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;elementwise_mean&quot;</span><span class="p">,</span>
+        <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span>
+        <span class="n">metrics_args_prep_fn</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">AbstractMetricsArgsPrepFn</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+    <span class="p">):</span>
+
+        <span class="k">if</span> <span class="n">num_classes</span> <span class="o">&lt;=</span> <span class="mi">1</span><span class="p">:</span>
+            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;IoU class only for multi-class usage! For binary usage, please call </span><span class="si">{</span><span class="n">BinaryIOU</span><span class="o">.</span><span class="vm">__name__</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
+
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">num_classes</span><span class="o">=</span><span class="n">num_classes</span><span class="p">,</span> <span class="n">dist_sync_on_step</span><span class="o">=</span><span class="n">dist_sync_on_step</span><span class="p">,</span> <span class="n">ignore_index</span><span class="o">=</span><span class="n">ignore_index</span><span class="p">,
         <span class="bp">self</span><span class="o">.</span><span class="n">metrics_args_prep_fn</span> <span class="o">=</span> <span class="n">metrics_args_prep_fn</span> <span class="ow">or</span> <span class="n">PreprocessSegmentationMetricsArgs</span><span class="p">(</span><span class="n">apply_arg_max</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">metrics_args_prep_fn</span> <span class="o">=</span> <span class="n">metrics_args_prep_fn</span> <span class="ow">or</span> <span class="n">PreprocessSegmentationMetricsArgs</span><span class="p">(</span><span class="n">apply_arg_max</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">greater_is_better</span> <span class="o">=</span> <span class="kc">True</span>
 
 
-<div class="viewcode-block" id="IoU.update"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.IoU.update">[docs]</a>    <span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">preds</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor
+<div class="viewcode-block" id="IoU.update"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.IoU.update">[docs]</a>    <span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">preds</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><
         <span class="n">preds</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">metrics_args_prep_fn</span><span class="p">(</span><span class="n">preds</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
         <span class="n">preds</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">metrics_args_prep_fn</span><span class="p">(</span><span class="n">preds</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
         <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">preds</span><span class="o">=</span><span class="n">preds</span><span class="p">,</span> <span class="n">target</span><span class="o">=</span><span class="n">target</span><span class="p">)</span></div></div>
         <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">preds</span><span class="o">=</span><span class="n">preds</span><span class="p">,</span> <span class="n">target</span><span class="o">=</span><span class="n">target</span><span class="p">)</span></div></div>
 
 
 
 
-<div class="viewcode-block" id="Dice"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.Dice">[docs]</a><span class="k">class</span> <span class="nc">Dice</span><span class="p">(</span><span class="n">torchmetrics</span><span class="o">.</span><span class="n">JaccardIndex</span><span class="p">):</span>
-    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
-                 <span class="n">num_classes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
-                 <span class="n">dist_sync_on_step</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
-                 <span class="n">ignore_index</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
-                 <span class="n">reduction</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;elementwise_mean&quot;</span><span class="p">,</span>
-                 <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span>
-                 <span class="n">metrics_args_prep_fn</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">AbstractMetricsArgsPrepFn</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
-        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">num_classes</span><span class="o">=</span><span class="n">num_classes</span><span class="p">,</span> <span class="n">dist_sync_on_step</span><span class="o">=</span><span class="n">dist_sync_on_step</span><span class="p">,</span> <span class="n">ignore_index</span><span class="o">=</span><span class="n">ignore_index</span><span class="p">,
-                         <span class="n">reduction</span><span class="o">=</span><span class="n">reduction</span><span class="p">,</span> <span class="n">threshold</span><span class="o">=</span><span class="n">threshold</span><span class="p">)</span>
+<div class="viewcode-block" id="Dice"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.Dice">[docs]</a><span class="k">class</span> <span class="nc">Dice</span><span class="p">(</span><span class="n">torchmetrics</span><span class="o">.</span><span class="n">JaccardIndex</span><span class="p">):</span>
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">num_classes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
+        <span class="n">dist_sync_on_step</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
+        <span class="n">ignore_index</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">reduction</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;elementwise_mean&quot;</span><span class="p">,</span>
+        <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span>
+        <span class="n">metrics_args_prep_fn</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">AbstractMetricsArgsPrepFn</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+    <span class="p">):</span>
+
+        <span class="k">if</span> <span class="n">num_classes</span> <span class="o">&lt;=</span> <span class="mi">1</span><span class="p">:</span>
+            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Dice class only for multi-class usage! For binary usage, please call </span><span class="si">{</span><span class="n">BinaryDice</span><span class="o">.</span><span class="vm">__name__</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
+
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">num_classes</span><span class="o">=</span><span class="n">num_classes</span><span class="p">,</span> <span class="n">dist_sync_on_step</span><span class="o">=</span><span class="n">dist_sync_on_step</span><span class="p">,</span> <span class="n">ignore_index</span><span class="o">=</span><span class="n">ignore_index</span><span class="p">,
         <span class="bp">self</span><span class="o">.</span><span class="n">metrics_args_prep_fn</span> <span class="o">=</span> <span class="n">metrics_args_prep_fn</span> <span class="ow">or</span> <span class="n">PreprocessSegmentationMetricsArgs</span><span class="p">(</span><span class="n">apply_arg_max</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">metrics_args_prep_fn</span> <span class="o">=</span> <span class="n">metrics_args_prep_fn</span> <span class="ow">or</span> <span class="n">PreprocessSegmentationMetricsArgs</span><span class="p">(</span><span class="n">apply_arg_max</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">greater_is_better</span> <span class="o">=</span> <span class="kc">True</span>
 
 
-<div class="viewcode-block" id="Dice.update"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.Dice.update">[docs]</a>    <span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">preds</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tens
+<div class="viewcode-block" id="Dice.update"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.Dice.update">[docs]</a>    <span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">preds</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span
         <span class="n">preds</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">metrics_args_prep_fn</span><span class="p">(</span><span class="n">preds</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
         <span class="n">preds</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">metrics_args_prep_fn</span><span class="p">(</span><span class="n">preds</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
         <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">preds</span><span class="o">=</span><span class="n">preds</span><span class="p">,</span> <span class="n">target</span><span class="o">=</span><span class="n">target</span><span class="p">)</span></div>
         <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">preds</span><span class="o">=</span><span class="n">preds</span><span class="p">,</span> <span class="n">target</span><span class="o">=</span><span class="n">target</span><span class="p">)</span></div>
 
 
-<div class="viewcode-block" id="Dice.compute"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.Dice.compute">[docs]</a>    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span>
+<div class="viewcode-block" id="Dice.compute"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.Dice.compute">[docs]</a>    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span>
         <span class="sd">&quot;&quot;&quot;Computes Dice coefficient&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;Computes Dice coefficient&quot;&quot;&quot;</span>
-        <span class="k">return</span> <span class="n">_dice_from_confmat</span><span class="p">(</span>
-            <span class="bp">self</span><span class="o">.</span><span class="n">confmat</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_classes</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">ignore_index</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">absent_score</span><span class="p">,</span> <span class="bp">self</span><span class="o">.
-        <span class="p">)</span></div></div>
+        <span class="k">return</span> <span class="n">_dice_from_confmat</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">confmat</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_classes</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">ignore_index</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n
 
 
 
 
-<div class="viewcode-block" id="BinaryIOU"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.BinaryIOU">[docs]</a><span class="k">class</span> <span class="nc">BinaryIOU</span><span class="p">(</span><span class="n">IoU</span><span class="p">):</span>
-    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
-                 <span class="n">dist_sync_on_step</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
-                 <span class="n">ignore_index</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
-                 <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span>
-                 <span class="n">metrics_args_prep_fn</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">AbstractMetricsArgsPrepFn</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
+<div class="viewcode-block" id="BinaryIOU"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.BinaryIOU">[docs]</a><span class="k">class</span> <span class="nc">BinaryIOU</span><span class="p">(</span><span class="n">IoU</span><span class="p">):</span>
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">dist_sync_on_step</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
+        <span class="n">ignore_index</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span>
+        <span class="n">metrics_args_prep_fn</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">AbstractMetricsArgsPrepFn</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+    <span class="p">):</span>
         <span class="n">metrics_args_prep_fn</span> <span class="o">=</span> <span class="n">metrics_args_prep_fn</span> <span class="ow">or</span> <span class="n">PreprocessSegmentationMetricsArgs</span><span class="p">(</span><span class="n">apply_sigmoid</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
         <span class="n">metrics_args_prep_fn</span> <span class="o">=</span> <span class="n">metrics_args_prep_fn</span> <span class="ow">or</span> <span class="n">PreprocessSegmentationMetricsArgs</span><span class="p">(</span><span class="n">apply_sigmoid</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
-        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">num_classes</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">dist_sync_on_step</span><span class="o">=</span><span class="n">dist_sync_on_step</span><span class="p">,</span> <span class="n">ignore_index</span><span class="o">=</span><span class="n">ignore_index</span><span class="p">,</span>
-                         <span class="n">reduction</span><span class="o">=</span><span class="s2">&quot;none&quot;</span><span class="p">,</span> <span class="n">threshold</span><span class="o">=</span><span class="n">threshold</span><span class="p">,</span> <span class="n">metrics_args_prep_fn</span><span class="o">=</span><span class="n">metrics_args_prep_fn</span><span class="p">)</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">component_names</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;target_IOU&quot;</span><span class="p">,</span> <span class="s2">&quot;background_IOU&quot;</span><span class="p">,</span> <span class="s2">&quot;mean_IOU&quot;</span><span class="p">]</span>
-
-<div class="viewcode-block" id="BinaryIOU.compute"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.BinaryIOU.compute">[docs]</a>    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
+            <span class="n">num_classes</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span>
+            <span class="n">dist_sync_on_step</span><span class="o">=</span><span class="n">dist_sync_on_step</span><span class="p">,</span>
+            <span class="n">ignore_index</span><span class="o">=</span><span class="n">ignore_index</span><span class="p">,</span>
+            <span class="n">reduction</span><span class="o">=</span><span class="s2">&quot;none&quot;</span><span class="p">,</span>
+            <span class="n">threshold</span><span class="o">=</span><span class="n">threshold</span><span class="p">,</span>
+            <span class="n">metrics_args_prep_fn</span><span class="o">=</span><span class="n">metrics_args_prep_fn</span><span class="p">,</span>
+        <span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">greater_component_is_better</span> <span class="o">=</span> <span class="p">{</span>
+            <span class="s2">&quot;target_IOU&quot;</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span>
+            <span class="s2">&quot;background_IOU&quot;</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span>
+            <span class="s2">&quot;mean_IOU&quot;</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span>
+        <span class="p">}</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">component_names</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">greater_component_is_better</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span>
+
+<div class="viewcode-block" id="BinaryIOU.compute"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.BinaryIOU.compute">[docs]</a>    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
         <span class="n">ious</span> <span class="o">=</span> <span class="nb">super</span><span class="p">(</span><span class="n">BinaryIOU</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">compute</span><span class="p">()</span>
         <span class="n">ious</span> <span class="o">=</span> <span class="nb">super</span><span class="p">(</span><span class="n">BinaryIOU</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">compute</span><span class="p">()</span>
         <span class="k">return</span> <span class="p">{</span><span class="s2">&quot;target_IOU&quot;</span><span class="p">:</span> <span class="n">ious</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="s2">&quot;background_IOU&quot;</span><span class="p">:</span> <span class="n">ious</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="s2">&quot;mean_IOU&quot;</span><span class="p">:</span> <span class="n">io
         <span class="k">return</span> <span class="p">{</span><span class="s2">&quot;target_IOU&quot;</span><span class="p">:</span> <span class="n">ious</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="s2">&quot;background_IOU&quot;</span><span class="p">:</span> <span class="n">ious</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="s2">&quot;mean_IOU&quot;</span><span class="p">:</span> <span class="n">io
 
 
 
 
-<div class="viewcode-block" id="BinaryDice"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.BinaryDice">[docs]</a><span class="k">class</span> <span class="nc">BinaryDice</span><span class="p">(</span><span class="n">Dice</span><span class="p">):</span>
-    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
-                 <span class="n">dist_sync_on_step</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
-                 <span class="n">ignore_index</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
-                 <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span>
-                 <span class="n">metrics_args_prep_fn</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">AbstractMetricsArgsPrepFn</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
+<div class="viewcode-block" id="BinaryDice"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.BinaryDice">[docs]</a><span class="k">class</span> <span class="nc">BinaryDice</span><span class="p">(</span><span class="n">Dice</span><span class="p">):</span>
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">dist_sync_on_step</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
+        <span class="n">ignore_index</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span>
+        <span class="n">metrics_args_prep_fn</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">AbstractMetricsArgsPrepFn</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+    <span class="p">):</span>
         <span class="n">metrics_args_prep_fn</span> <span class="o">=</span> <span class="n">metrics_args_prep_fn</span> <span class="ow">or</span> <span class="n">PreprocessSegmentationMetricsArgs</span><span class="p">(</span><span class="n">apply_sigmoid</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
         <span class="n">metrics_args_prep_fn</span> <span class="o">=</span> <span class="n">metrics_args_prep_fn</span> <span class="ow">or</span> <span class="n">PreprocessSegmentationMetricsArgs</span><span class="p">(</span><span class="n">apply_sigmoid</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
-        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">num_classes</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">dist_sync_on_step</span><span class="o">=</span><span class="n">dist_sync_on_step</span><span class="p">,</span> <span class="n">ignore_index</span><span class="o">=</span><span class="n">ignore_index</span><span class="p">,</span>
-                         <span class="n">reduction</span><span class="o">=</span><span class="s2">&quot;none&quot;</span><span class="p">,</span> <span class="n">threshold</span><span class="o">=</span><span class="n">threshold</span><span class="p">,</span> <span class="n">metrics_args_prep_fn</span><span class="o">=</span><span class="n">metrics_args_prep_fn</span><span class="p">)</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">component_names</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;target_Dice&quot;</span><span class="p">,</span> <span class="s2">&quot;background_Dice&quot;</span><span class="p">,</span> <span class="s2">&quot;mean_Dice&quot;</span><span class="p">]</span>
-
-<div class="viewcode-block" id="BinaryDice.compute"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.BinaryDice.compute">[docs]</a>    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
+            <span class="n">num_classes</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span>
+            <span class="n">dist_sync_on_step</span><span class="o">=</span><span class="n">dist_sync_on_step</span><span class="p">,</span>
+            <span class="n">ignore_index</span><span class="o">=</span><span class="n">ignore_index</span><span class="p">,</span>
+            <span class="n">reduction</span><span class="o">=</span><span class="s2">&quot;none&quot;</span><span class="p">,</span>
+            <span class="n">threshold</span><span class="o">=</span><span class="n">threshold</span><span class="p">,</span>
+            <span class="n">metrics_args_prep_fn</span><span class="o">=</span><span class="n">metrics_args_prep_fn</span><span class="p">,</span>
+        <span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">greater_component_is_better</span> <span class="o">=</span> <span class="p">{</span>
+            <span class="s2">&quot;target_Dice&quot;</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span>
+            <span class="s2">&quot;background_Dice&quot;</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span>
+            <span class="s2">&quot;mean_Dice&quot;</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span>
+        <span class="p">}</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">component_names</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">greater_component_is_better</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span>
+
+<div class="viewcode-block" id="BinaryDice.compute"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.metrics.BinaryDice.compute">[docs]</a>    <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
         <span class="n">dices</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">compute</span><span class="p">()</span>
         <span class="n">dices</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">compute</span><span class="p">()</span>
         <span class="k">return</span> <span class="p">{</span><span class="s2">&quot;target_Dice&quot;</span><span class="p">:</span> <span class="n">dices</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="s2">&quot;background_Dice&quot;</span><span class="p">:</span> <span class="n">dices</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="s2">&quot;mean_Dice&quot;</span><span class="p">:</span> <span class="
         <span class="k">return</span> <span class="p">{</span><span class="s2">&quot;target_Dice&quot;</span><span class="p">:</span> <span class="n">dices</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="s2">&quot;background_Dice&quot;</span><span class="p">:</span> <span class="n">dices</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="s2">&quot;mean_Dice&quot;</span><span class="p">:</span> <span class="
 </pre></div>
 </pre></div>
@@ -379,4 +410,4 @@
   </script> 
   </script> 
 
 
 </body>
 </body>
-</html>
+</html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.training.models.sg_module &mdash; SuperGradients 1.0 documentation</title>
  7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
  10. <!--[if lt IE 9]>
  11. <script src="../../../../_static/js/html5shiv.min.js"></script>
  12. <![endif]-->
  13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
  14. <script src="../../../../_static/jquery.js"></script>
  15. <script src="../../../../_static/underscore.js"></script>
  16. <script src="../../../../_static/doctools.js"></script>
  17. <script src="../../../../_static/js/theme.js"></script>
  18. <link rel="index" title="Index" href="../../../../genindex.html" />
  19. <link rel="search" title="Search" href="../../../../search.html" />
  20. </head>
  21. <body class="wy-body-for-nav">
  22. <div class="wy-grid-for-nav">
  23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  24. <div class="wy-side-scroll">
  25. <div class="wy-side-nav-search" >
  26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
  27. </a>
  28. <div role="search">
  29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
  30. <input type="text" name="q" placeholder="Search docs" />
  31. <input type="hidden" name="check_keywords" value="yes" />
  32. <input type="hidden" name="area" value="default" />
  33. </form>
  34. </div>
  35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
  37. <ul>
  38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
  39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
  40. </ul>
  41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
  42. <ul>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
  45. </ul>
  46. <p class="caption"><span class="caption-text">User Guide</span></p>
  47. <ul>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
  57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
  58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
  59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
  60. </ul>
  61. </div>
  62. </div>
  63. </nav>
  64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  66. <a href="../../../../index.html">SuperGradients</a>
  67. </nav>
  68. <div class="wy-nav-content">
  69. <div class="rst-content">
  70. <div role="navigation" aria-label="Page navigation">
  71. <ul class="wy-breadcrumbs">
  72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
  74. <li>super_gradients.training.models.sg_module</li>
  75. <li class="wy-breadcrumbs-aside">
  76. </li>
  77. </ul>
  78. <hr/>
  79. </div>
  80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  81. <div itemprop="articleBody">
  82. <h1>Source code for super_gradients.training.models.sg_module</h1><div class="highlight"><pre>
  83. <span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Union</span>
  84. <span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
  85. <span class="kn">from</span> <span class="nn">super_gradients.training.utils.utils</span> <span class="kn">import</span> <span class="n">HpmStruct</span>
  86. <div class="viewcode-block" id="SgModule"><a class="viewcode-back" href="../../../../super_gradients.training.models.html#super_gradients.training.models.sg_module.SgModule">[docs]</a><span class="k">class</span> <span class="nc">SgModule</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
  87. <div class="viewcode-block" id="SgModule.initialize_param_groups"><a class="viewcode-back" href="../../../../super_gradients.training.models.html#super_gradients.training.models.sg_module.SgModule.initialize_param_groups">[docs]</a> <span class="k">def</span> <span class="nf">initialize_param_groups</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">lr</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">training_params</span><span class="p">:</span> <span class="n">HpmStruct</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">list</span><span class="p">:</span>
  88. <span class="sd">&quot;&quot;&quot;</span>
  89. <span class="sd"> :return: list of dictionaries containing the key &#39;named_params&#39; with a list of named params</span>
  90. <span class="sd"> &quot;&quot;&quot;</span>
  91. <span class="k">return</span> <span class="p">[{</span><span class="s2">&quot;named_params&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">named_parameters</span><span class="p">()}]</span></div>
  92. <div class="viewcode-block" id="SgModule.update_param_groups"><a class="viewcode-back" href="../../../../super_gradients.training.models.html#super_gradients.training.models.sg_module.SgModule.update_param_groups">[docs]</a> <span class="k">def</span> <span class="nf">update_param_groups</span><span class="p">(</span>
  93. <span class="bp">self</span><span class="p">,</span> <span class="n">param_groups</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span> <span class="n">lr</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">epoch</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="nb">iter</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">training_params</span><span class="p">:</span> <span class="n">HpmStruct</span><span class="p">,</span> <span class="n">total_batch</span><span class="p">:</span> <span class="nb">int</span>
  94. <span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">list</span><span class="p">:</span>
  95. <span class="sd">&quot;&quot;&quot;</span>
  96. <span class="sd"> :param param_groups: list of dictionaries containing the params</span>
  97. <span class="sd"> :return: list of dictionaries containing the params</span>
  98. <span class="sd"> &quot;&quot;&quot;</span>
  99. <span class="k">for</span> <span class="n">param_group</span> <span class="ow">in</span> <span class="n">param_groups</span><span class="p">:</span>
  100. <span class="n">param_group</span><span class="p">[</span><span class="s2">&quot;lr&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">lr</span>
  101. <span class="k">return</span> <span class="n">param_groups</span></div>
  102. <div class="viewcode-block" id="SgModule.get_include_attributes"><a class="viewcode-back" href="../../../../super_gradients.training.models.html#super_gradients.training.models.sg_module.SgModule.get_include_attributes">[docs]</a> <span class="k">def</span> <span class="nf">get_include_attributes</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">list</span><span class="p">:</span>
  103. <span class="sd">&quot;&quot;&quot;</span>
  104. <span class="sd"> This function is used by the EMA. When updating the EMA model, some attributes of the main model (used in training)</span>
  105. <span class="sd"> are updated to the EMA model along with the model weights.</span>
  106. <span class="sd"> By default, all attributes are updated except for private attributes (starting with &#39;_&#39;)</span>
  107. <span class="sd"> You can either set include_attributes or exclude_attributes. By returning a non empty list from this function,</span>
  108. <span class="sd"> you override the default behaviour and only attributes named in this list will be updated.</span>
  109. <span class="sd"> Note: This will also override the get_exclude_attributes list.</span>
  110. <span class="sd"> :return: list of attributes to update from main model to EMA model</span>
  111. <span class="sd"> &quot;&quot;&quot;</span>
  112. <span class="k">return</span> <span class="p">[]</span></div>
  113. <div class="viewcode-block" id="SgModule.get_exclude_attributes"><a class="viewcode-back" href="../../../../super_gradients.training.models.html#super_gradients.training.models.sg_module.SgModule.get_exclude_attributes">[docs]</a> <span class="k">def</span> <span class="nf">get_exclude_attributes</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">list</span><span class="p">:</span>
  114. <span class="sd">&quot;&quot;&quot;</span>
  115. <span class="sd"> This function is used by the EMA. When updating the EMA model, some attributes of the main model (used in training)</span>
  116. <span class="sd"> are updated to the EMA model along with the model weights.</span>
  117. <span class="sd"> By default, all attributes are updated except for private attributes (starting with &#39;_&#39;)</span>
  118. <span class="sd"> You can either set include_attributes or exclude_attributes. By returning a non empty list from this function,</span>
  119. <span class="sd"> you override the default behaviour and attributes named in this list will also be excluded from update.</span>
  120. <span class="sd"> Note: if get_include_attributes is not empty, it will override this list.</span>
  121. <span class="sd"> :return: list of attributes to not update from main model to EMA mode</span>
  122. <span class="sd"> &quot;&quot;&quot;</span>
  123. <span class="k">return</span> <span class="p">[]</span></div>
  124. <div class="viewcode-block" id="SgModule.prep_model_for_conversion"><a class="viewcode-back" href="../../../../super_gradients.training.models.html#super_gradients.training.models.sg_module.SgModule.prep_model_for_conversion">[docs]</a> <span class="k">def</span> <span class="nf">prep_model_for_conversion</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_size</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">tuple</span><span class="p">,</span> <span class="nb">list</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  125. <span class="sd">&quot;&quot;&quot;</span>
  126. <span class="sd"> Prepare the model to be converted to ONNX or other frameworks.</span>
  127. <span class="sd"> Typically, this function will freeze the size of layers which is otherwise flexible, replace some modules</span>
  128. <span class="sd"> with convertible substitutes and remove all auxiliary or training related parts.</span>
  129. <span class="sd"> :param input_size: [H,W]</span>
  130. <span class="sd"> &quot;&quot;&quot;</span></div>
  131. <div class="viewcode-block" id="SgModule.replace_head"><a class="viewcode-back" href="../../../../super_gradients.training.models.html#super_gradients.training.models.sg_module.SgModule.replace_head">[docs]</a> <span class="k">def</span> <span class="nf">replace_head</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  132. <span class="sd">&quot;&quot;&quot;</span>
  133. <span class="sd"> Replace final layer for pretrained models. Since this varies between architectures, we leave it to the inheriting</span>
  134. <span class="sd"> class to implement.</span>
  135. <span class="sd"> &quot;&quot;&quot;</span>
  136. <span class="k">raise</span> <span class="ne">NotImplementedError</span></div></div>
  137. </pre></div>
  138. </div>
  139. </div>
  140. <footer>
  141. <hr/>
  142. <div role="contentinfo">
  143. <p>&#169; Copyright 2021, SuperGradients team.</p>
  144. </div>
  145. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  146. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  147. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  148. </footer>
  149. </div>
  150. </div>
  151. </section>
  152. </div>
  153. <script>
  154. jQuery(function () {
  155. SphinxRtdTheme.Navigation.enable(true);
  156. });
  157. </script>
  158. </body>
  159. </html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.training.params &mdash; SuperGradients 1.0 documentation</title>
  7. <link rel="stylesheet" href="../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../_static/graphviz.css" type="text/css" />
  10. <!--[if lt IE 9]>
  11. <script src="../../../_static/js/html5shiv.min.js"></script>
  12. <![endif]-->
  13. <script data-url_root="../../../" id="documentation_options" src="../../../_static/documentation_options.js"></script>
  14. <script src="../../../_static/jquery.js"></script>
  15. <script src="../../../_static/underscore.js"></script>
  16. <script src="../../../_static/doctools.js"></script>
  17. <script src="../../../_static/js/theme.js"></script>
  18. <link rel="index" title="Index" href="../../../genindex.html" />
  19. <link rel="search" title="Search" href="../../../search.html" />
  20. </head>
  21. <body class="wy-body-for-nav">
  22. <div class="wy-grid-for-nav">
  23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  24. <div class="wy-side-scroll">
  25. <div class="wy-side-nav-search" >
  26. <a href="../../../index.html" class="icon icon-home"> SuperGradients
  27. </a>
  28. <div role="search">
  29. <form id="rtd-search-form" class="wy-form" action="../../../search.html" method="get">
  30. <input type="text" name="q" placeholder="Search docs" />
  31. <input type="hidden" name="check_keywords" value="yes" />
  32. <input type="hidden" name="area" value="default" />
  33. </form>
  34. </div>
  35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  36. <p class="caption"><span class="caption-text">SuperGradients</span></p>
  37. <ul>
  38. <li class="toctree-l1"><a class="reference internal" href="../../../welcome.html">SuperGradients</a></li>
  39. <li class="toctree-l1"><a class="reference internal" href="../../../super_gradients.common.html">Common</a></li>
  40. <li class="toctree-l1"><a class="reference internal" href="../../../super_gradients.training.html">Training</a></li>
  41. </ul>
  42. </div>
  43. </div>
  44. </nav>
  45. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  46. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  47. <a href="../../../index.html">SuperGradients</a>
  48. </nav>
  49. <div class="wy-nav-content">
  50. <div class="rst-content">
  51. <div role="navigation" aria-label="Page navigation">
  52. <ul class="wy-breadcrumbs">
  53. <li><a href="../../../index.html" class="icon icon-home"></a> &raquo;</li>
  54. <li><a href="../../index.html">Module code</a> &raquo;</li>
  55. <li>super_gradients.training.params</li>
  56. <li class="wy-breadcrumbs-aside">
  57. </li>
  58. </ul>
  59. <hr/>
  60. </div>
  61. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  62. <div itemprop="articleBody">
  63. <h1>Source code for super_gradients.training.params</h1><div class="highlight"><pre>
  64. <span></span><span class="kn">from</span> <span class="nn">super_gradients.training.utils</span> <span class="kn">import</span> <span class="n">HpmStruct</span>
  65. <span class="n">DEFAULT_TRAINING_PARAMS</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;lr_warmup_epochs&quot;</span><span class="p">:</span> <span class="mi">0</span><span class="p">,</span>
  66. <span class="s2">&quot;cosine_final_lr_ratio&quot;</span><span class="p">:</span> <span class="mf">0.01</span><span class="p">,</span>
  67. <span class="s2">&quot;optimizer&quot;</span><span class="p">:</span> <span class="s2">&quot;SGD&quot;</span><span class="p">,</span>
  68. <span class="s2">&quot;criterion_params&quot;</span><span class="p">:</span> <span class="p">{},</span>
  69. <span class="s2">&quot;ema&quot;</span><span class="p">:</span> <span class="kc">False</span><span class="p">,</span>
  70. <span class="s2">&quot;batch_accumulate&quot;</span><span class="p">:</span> <span class="mi">1</span><span class="p">,</span> <span class="c1"># number of batches to accumulate before every backward pass</span>
  71. <span class="s2">&quot;ema_params&quot;</span><span class="p">:</span> <span class="p">{},</span>
  72. <span class="s2">&quot;zero_weight_decay_on_bias_and_bn&quot;</span><span class="p">:</span> <span class="kc">False</span><span class="p">,</span>
  73. <span class="s2">&quot;load_opt_params&quot;</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span>
  74. <span class="s2">&quot;run_validation_freq&quot;</span><span class="p">:</span> <span class="mi">1</span><span class="p">,</span>
  75. <span class="s2">&quot;save_model&quot;</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span>
  76. <span class="s2">&quot;metric_to_watch&quot;</span><span class="p">:</span> <span class="s2">&quot;Accuracy&quot;</span><span class="p">,</span>
  77. <span class="s2">&quot;launch_tensorboard&quot;</span><span class="p">:</span> <span class="kc">False</span><span class="p">,</span>
  78. <span class="s2">&quot;tb_files_user_prompt&quot;</span><span class="p">:</span> <span class="kc">False</span><span class="p">,</span> <span class="c1"># Asks User for Tensorboard Deletion Prompt</span>
  79. <span class="s2">&quot;silent_mode&quot;</span><span class="p">:</span> <span class="kc">False</span><span class="p">,</span> <span class="c1"># Silents the Print outs</span>
  80. <span class="s2">&quot;mixed_precision&quot;</span><span class="p">:</span> <span class="kc">False</span><span class="p">,</span>
  81. <span class="s2">&quot;tensorboard_port&quot;</span><span class="p">:</span> <span class="kc">None</span><span class="p">,</span>
  82. <span class="s2">&quot;save_ckpt_epoch_list&quot;</span><span class="p">:</span> <span class="p">[],</span> <span class="c1"># indices where the ckpt will save automatically</span>
  83. <span class="s2">&quot;average_best_models&quot;</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span>
  84. <span class="s2">&quot;dataset_statistics&quot;</span><span class="p">:</span> <span class="kc">False</span><span class="p">,</span> <span class="c1"># add a dataset statistical analysis and sample images to tensorboard</span>
  85. <span class="s2">&quot;save_tensorboard_to_s3&quot;</span><span class="p">:</span> <span class="kc">False</span><span class="p">,</span>
  86. <span class="s2">&quot;lr_schedule_function&quot;</span><span class="p">:</span> <span class="kc">None</span><span class="p">,</span>
  87. <span class="s2">&quot;train_metrics_list&quot;</span><span class="p">:</span> <span class="p">[],</span>
  88. <span class="s2">&quot;valid_metrics_list&quot;</span><span class="p">:</span> <span class="p">[],</span>
  89. <span class="s2">&quot;loss_logging_items_names&quot;</span><span class="p">:</span> <span class="p">[</span><span class="s2">&quot;Loss&quot;</span><span class="p">],</span>
  90. <span class="s2">&quot;greater_metric_to_watch_is_better&quot;</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span>
  91. <span class="s2">&quot;precise_bn&quot;</span><span class="p">:</span> <span class="kc">False</span><span class="p">,</span>
  92. <span class="s2">&quot;precise_bn_batch_size&quot;</span><span class="p">:</span> <span class="kc">None</span><span class="p">,</span>
  93. <span class="s2">&quot;seed&quot;</span><span class="p">:</span> <span class="mi">42</span><span class="p">,</span>
  94. <span class="s2">&quot;lr_mode&quot;</span><span class="p">:</span> <span class="kc">None</span><span class="p">,</span>
  95. <span class="s2">&quot;phase_callbacks&quot;</span><span class="p">:</span> <span class="p">[],</span>
  96. <span class="s2">&quot;log_installed_packages&quot;</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span>
  97. <span class="s2">&quot;save_full_train_log&quot;</span><span class="p">:</span> <span class="kc">False</span><span class="p">,</span>
  98. <span class="s2">&quot;sg_logger&quot;</span><span class="p">:</span> <span class="s2">&quot;base_sg_logger&quot;</span><span class="p">,</span>
  99. <span class="s2">&quot;sg_logger_params&quot;</span><span class="p">:</span>
  100. <span class="p">{</span><span class="s2">&quot;tb_files_user_prompt&quot;</span><span class="p">:</span> <span class="kc">False</span><span class="p">,</span> <span class="c1"># Asks User for Tensorboard Deletion Prompt</span>
  101. <span class="s2">&quot;project_name&quot;</span><span class="p">:</span> <span class="s2">&quot;&quot;</span><span class="p">,</span>
  102. <span class="s2">&quot;launch_tensorboard&quot;</span><span class="p">:</span> <span class="kc">False</span><span class="p">,</span>
  103. <span class="s2">&quot;tensorboard_port&quot;</span><span class="p">:</span> <span class="kc">None</span><span class="p">,</span>
  104. <span class="s2">&quot;save_checkpoints_remote&quot;</span><span class="p">:</span> <span class="kc">False</span><span class="p">,</span> <span class="c1"># upload checkpoint files to s3</span>
  105. <span class="s2">&quot;save_tensorboard_remote&quot;</span><span class="p">:</span> <span class="kc">False</span><span class="p">,</span> <span class="c1"># upload tensorboard files to s3</span>
  106. <span class="s2">&quot;save_logs_remote&quot;</span><span class="p">:</span> <span class="kc">False</span><span class="p">},</span> <span class="c1"># upload log files to s3</span>
  107. <span class="s2">&quot;warmup_mode&quot;</span><span class="p">:</span> <span class="s2">&quot;linear_step&quot;</span>
  108. <span class="p">}</span>
  109. <span class="n">DEFAULT_OPTIMIZER_PARAMS_SGD</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;weight_decay&quot;</span><span class="p">:</span> <span class="mf">1e-4</span><span class="p">,</span> <span class="s2">&quot;momentum&quot;</span><span class="p">:</span> <span class="mf">0.9</span><span class="p">}</span>
  110. <span class="n">DEFAULT_OPTIMIZER_PARAMS_ADAM</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;weight_decay&quot;</span><span class="p">:</span> <span class="mf">1e-4</span><span class="p">}</span>
  111. <span class="n">DEFAULT_OPTIMIZER_PARAMS_RMSPROP</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;weight_decay&quot;</span><span class="p">:</span> <span class="mf">1e-4</span><span class="p">,</span> <span class="s2">&quot;momentum&quot;</span><span class="p">:</span> <span class="mf">0.9</span><span class="p">}</span>
  112. <span class="n">DEFAULT_OPTIMIZER_PARAMS_RMSPROPTF</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;weight_decay&quot;</span><span class="p">:</span> <span class="mf">1e-4</span><span class="p">,</span> <span class="s2">&quot;momentum&quot;</span><span class="p">:</span> <span class="mf">0.9</span><span class="p">}</span>
  113. <span class="n">TRAINING_PARAM_SCHEMA</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;type&quot;</span><span class="p">:</span> <span class="s2">&quot;object&quot;</span><span class="p">,</span>
  114. <span class="s2">&quot;properties&quot;</span><span class="p">:</span> <span class="p">{</span>
  115. <span class="s2">&quot;max_epochs&quot;</span><span class="p">:</span> <span class="p">{</span><span class="s2">&quot;type&quot;</span><span class="p">:</span> <span class="s2">&quot;number&quot;</span><span class="p">,</span> <span class="s2">&quot;minimum&quot;</span><span class="p">:</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">&quot;maximum&quot;</span><span class="p">:</span> <span class="mi">800</span><span class="p">},</span>
  116. <span class="c1"># FIXME: CHECK THE IMPORTANCE OF THE COMMENTED SCHEMA- AS IT CAUSES HYDRA USE TO CRASH</span>
  117. <span class="c1"># &quot;lr_updates&quot;: {&quot;type&quot;: &quot;array&quot;, &quot;minItems&quot;: 1},</span>
  118. <span class="s2">&quot;lr_decay_factor&quot;</span><span class="p">:</span> <span class="p">{</span><span class="s2">&quot;type&quot;</span><span class="p">:</span> <span class="s2">&quot;number&quot;</span><span class="p">,</span> <span class="s2">&quot;minimum&quot;</span><span class="p">:</span> <span class="mi">0</span><span class="p">,</span> <span class="s2">&quot;maximum&quot;</span><span class="p">:</span> <span class="mi">1</span><span class="p">},</span>
  119. <span class="s2">&quot;lr_warmup_epochs&quot;</span><span class="p">:</span> <span class="p">{</span><span class="s2">&quot;type&quot;</span><span class="p">:</span> <span class="s2">&quot;number&quot;</span><span class="p">,</span> <span class="s2">&quot;minimum&quot;</span><span class="p">:</span> <span class="mi">0</span><span class="p">,</span> <span class="s2">&quot;maximum&quot;</span><span class="p">:</span> <span class="mi">10</span><span class="p">},</span>
  120. <span class="s2">&quot;initial_lr&quot;</span><span class="p">:</span> <span class="p">{</span><span class="s2">&quot;type&quot;</span><span class="p">:</span> <span class="s2">&quot;number&quot;</span><span class="p">,</span> <span class="s2">&quot;exclusiveMinimum&quot;</span><span class="p">:</span> <span class="mi">0</span><span class="p">,</span> <span class="s2">&quot;maximum&quot;</span><span class="p">:</span> <span class="mi">10</span><span class="p">}</span>
  121. <span class="p">},</span>
  122. <span class="s2">&quot;if&quot;</span><span class="p">:</span> <span class="p">{</span>
  123. <span class="s2">&quot;properties&quot;</span><span class="p">:</span> <span class="p">{</span><span class="s2">&quot;lr_mode&quot;</span><span class="p">:</span> <span class="p">{</span><span class="s2">&quot;const&quot;</span><span class="p">:</span> <span class="s2">&quot;step&quot;</span><span class="p">}}</span>
  124. <span class="p">},</span>
  125. <span class="s2">&quot;then&quot;</span><span class="p">:</span> <span class="p">{</span>
  126. <span class="s2">&quot;required&quot;</span><span class="p">:</span> <span class="p">[</span><span class="s2">&quot;lr_updates&quot;</span><span class="p">,</span> <span class="s2">&quot;lr_decay_factor&quot;</span><span class="p">]</span>
  127. <span class="p">},</span>
  128. <span class="s2">&quot;required&quot;</span><span class="p">:</span> <span class="p">[</span><span class="s2">&quot;max_epochs&quot;</span><span class="p">,</span> <span class="s2">&quot;lr_mode&quot;</span><span class="p">,</span> <span class="s2">&quot;initial_lr&quot;</span><span class="p">,</span> <span class="s2">&quot;loss&quot;</span><span class="p">]</span>
  129. <span class="p">}</span>
  130. <div class="viewcode-block" id="TrainingParams"><a class="viewcode-back" href="../../../super_gradients.training.html#super_gradients.training.params.TrainingParams">[docs]</a><span class="k">class</span> <span class="nc">TrainingParams</span><span class="p">(</span><span class="n">HpmStruct</span><span class="p">):</span>
  131. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">entries</span><span class="p">):</span>
  132. <span class="c1"># WE initialize by the default training params, overridden by the provided params</span>
  133. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">DEFAULT_TRAINING_PARAMS</span><span class="p">)</span>
  134. <span class="bp">self</span><span class="o">.</span><span class="n">set_schema</span><span class="p">(</span><span class="n">TRAINING_PARAM_SCHEMA</span><span class="p">)</span>
  135. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">entries</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
  136. <span class="bp">self</span><span class="o">.</span><span class="n">override</span><span class="p">(</span><span class="o">**</span><span class="n">entries</span><span class="p">)</span>
  137. <div class="viewcode-block" id="TrainingParams.override"><a class="viewcode-back" href="../../../super_gradients.training.html#super_gradients.training.params.TrainingParams.override">[docs]</a> <span class="k">def</span> <span class="nf">override</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">entries</span><span class="p">):</span>
  138. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">override</span><span class="p">(</span><span class="o">**</span><span class="n">entries</span><span class="p">)</span>
  139. <span class="bp">self</span><span class="o">.</span><span class="n">validate</span><span class="p">()</span></div></div>
  140. </pre></div>
  141. </div>
  142. </div>
  143. <footer>
  144. <hr/>
  145. <div role="contentinfo">
  146. <p>&#169; Copyright 2021, SuperGradients team.</p>
  147. </div>
  148. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  149. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  150. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  151. </footer>
  152. </div>
  153. </div>
  154. </section>
  155. </div>
  156. <script>
  157. jQuery(function () {
  158. SphinxRtdTheme.Navigation.enable(true);
  159. });
  160. </script>
  161. </body>
  162. </html>
Discard
Only showing up to 1000 lines per file, please use a local Git client to see the full diff.
Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
@@ -3,10 +3,11 @@
 <head>
 <head>
   <meta charset="utf-8" />
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-  <title>super_gradients.training.sg_model.sg_model &mdash; SuperGradients 1.0 documentation</title>
+  <title>super_gradients.training.sg_trainer.sg_trainer &mdash; SuperGradients 3.0.3 documentation</title>
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
+      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
   <!--[if lt IE 9]>
   <!--[if lt IE 9]>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
   <![endif]-->
   <![endif]-->
@@ -14,7 +15,9 @@
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/underscore.js"></script>
         <script src="../../../../_static/underscore.js"></script>
+        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
         <script src="../../../../_static/doctools.js"></script>
         <script src="../../../../_static/doctools.js"></script>
+        <script src="../../../../_static/sphinx_highlight.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="search" title="Search" href="../../../../search.html" /> 
     <link rel="search" title="Search" href="../../../../search.html" /> 
@@ -35,30 +38,29 @@
   </form>
   </form>
 </div>
 </div>
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
-              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
+              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
 <ul>
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
 </ul>
 </ul>
-<p class="caption"><span class="caption-text">Technical Documentation</span></p>
+<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
 <ul>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
-</ul>
-<p class="caption"><span class="caption-text">User Guide</span></p>
-<ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
 </ul>
 </ul>
 
 
         </div>
         </div>
@@ -76,7 +78,7 @@
   <ul class="wy-breadcrumbs">
   <ul class="wy-breadcrumbs">
       <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
       <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
           <li><a href="../../../index.html">Module code</a> &raquo;</li>
           <li><a href="../../../index.html">Module code</a> &raquo;</li>
-      <li>super_gradients.training.sg_model.sg_model</li>
+      <li>super_gradients.training.sg_trainer.sg_trainer</li>
       <li class="wy-breadcrumbs-aside">
       <li class="wy-breadcrumbs-aside">
       </li>
       </li>
   </ul>
   </ul>
@@ -85,74 +87,99 @@
           <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
           <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
            <div itemprop="articleBody">
            <div itemprop="articleBody">
              
              
-  <h1>Source code for super_gradients.training.sg_model.sg_model</h1><div class="highlight"><pre>
+  <h1>Source code for super_gradients.training.sg_trainer.sg_trainer</h1><div class="highlight"><pre>
 <span></span><span class="kn">import</span> <span class="nn">inspect</span>
 <span></span><span class="kn">import</span> <span class="nn">inspect</span>
 <span class="kn">import</span> <span class="nn">os</span>
 <span class="kn">import</span> <span class="nn">os</span>
-<span class="kn">import</span> <span class="nn">sys</span>
 <span class="kn">from</span> <span class="nn">copy</span> <span class="kn">import</span> <span class="n">deepcopy</span>
 <span class="kn">from</span> <span class="nn">copy</span> <span class="kn">import</span> <span class="n">deepcopy</span>
-<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Union</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Mapping</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Any</span>
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Union</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Mapping</span><span class="p">,</span> <span class="n">Dict</span>
+<span class="kn">from</span> <span class="nn">pathlib</span> <span class="kn">import</span> <span class="n">Path</span>
 
 
 <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
 <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
-<span class="kn">import</span> <span class="nn">pkg_resources</span>
 <span class="kn">import</span> <span class="nn">torch</span>
 <span class="kn">import</span> <span class="nn">torch</span>
-<span class="kn">import</span> <span class="nn">torchvision.transforms</span> <span class="k">as</span> <span class="nn">transforms</span>
-<span class="kn">from</span> <span class="nn">deprecated</span> <span class="kn">import</span> <span class="n">deprecated</span>
+<span class="kn">import</span> <span class="nn">hydra</span>
+<span class="kn">from</span> <span class="nn">omegaconf</span> <span class="kn">import</span> <span class="n">DictConfig</span>
 <span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
 <span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
-<span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">DataLoader</span><span class="p">,</span> <span class="n">DistributedSampler</span>
+<span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">DataLoader</span><span class="p">,</span> <span class="n">SequentialSampler</span>
 <span class="kn">from</span> <span class="nn">torch.cuda.amp</span> <span class="kn">import</span> <span class="n">GradScaler</span><span class="p">,</span> <span class="n">autocast</span>
 <span class="kn">from</span> <span class="nn">torch.cuda.amp</span> <span class="kn">import</span> <span class="n">GradScaler</span><span class="p">,</span> <span class="n">autocast</span>
 <span class="kn">from</span> <span class="nn">torchmetrics</span> <span class="kn">import</span> <span class="n">MetricCollection</span>
 <span class="kn">from</span> <span class="nn">torchmetrics</span> <span class="kn">import</span> <span class="n">MetricCollection</span>
 <span class="kn">from</span> <span class="nn">tqdm</span> <span class="kn">import</span> <span class="n">tqdm</span>
 <span class="kn">from</span> <span class="nn">tqdm</span> <span class="kn">import</span> <span class="n">tqdm</span>
 <span class="kn">from</span> <span class="nn">piptools.scripts.sync</span> <span class="kn">import</span> <span class="n">_get_installed_distributions</span>
 <span class="kn">from</span> <span class="nn">piptools.scripts.sync</span> <span class="kn">import</span> <span class="n">_get_installed_distributions</span>
 
 
+<span class="kn">from</span> <span class="nn">torch.utils.data.distributed</span> <span class="kn">import</span> <span class="n">DistributedSampler</span>
+
+<span class="kn">from</span> <span class="nn">super_gradients.common.factories.type_factory</span> <span class="kn">import</span> <span class="n">TypeFactory</span>
+<span class="kn">from</span> <span class="nn">super_gradients.training.datasets.samplers</span> <span class="kn">import</span> <span class="n">InfiniteSampler</span><span class="p">,</span> <span class="n">RepeatAugSampler</span>
+
 <span class="kn">from</span> <span class="nn">super_gradients.common.factories.callbacks_factory</span> <span class="kn">import</span> <span class="n">CallbacksFactory</span>
 <span class="kn">from</span> <span class="nn">super_gradients.common.factories.callbacks_factory</span> <span class="kn">import</span> <span class="n">CallbacksFactory</span>
+<span class="kn">from</span> <span class="nn">super_gradients.common.data_types.enum</span> <span class="kn">import</span> <span class="n">MultiGPUMode</span><span class="p">,</span> <span class="n">StrictLoad</span><span class="p">,</span> <span class="n">EvaluationType</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.models.all_architectures</span> <span class="kn">import</span> <span class="n">ARCHITECTURES</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.models.all_architectures</span> <span class="kn">import</span> <span class="n">ARCHITECTURES</span>
 <span class="kn">from</span> <span class="nn">super_gradients.common.decorators.factory_decorator</span> <span class="kn">import</span> <span class="n">resolve_param</span>
 <span class="kn">from</span> <span class="nn">super_gradients.common.decorators.factory_decorator</span> <span class="kn">import</span> <span class="n">resolve_param</span>
 <span class="kn">from</span> <span class="nn">super_gradients.common.environment</span> <span class="kn">import</span> <span class="n">env_helpers</span>
 <span class="kn">from</span> <span class="nn">super_gradients.common.environment</span> <span class="kn">import</span> <span class="n">env_helpers</span>
-<span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
-<span class="kn">from</span> <span class="nn">super_gradients.common.factories.datasets_factory</span> <span class="kn">import</span> <span class="n">DatasetsFactory</span>
+<span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span><span class="p">,</span> <span class="n">mute_current_process</span>
 <span class="kn">from</span> <span class="nn">super_gradients.common.factories.list_factory</span> <span class="kn">import</span> <span class="n">ListFactory</span>
 <span class="kn">from</span> <span class="nn">super_gradients.common.factories.list_factory</span> <span class="kn">import</span> <span class="n">ListFactory</span>
 <span class="kn">from</span> <span class="nn">super_gradients.common.factories.losses_factory</span> <span class="kn">import</span> <span class="n">LossesFactory</span>
 <span class="kn">from</span> <span class="nn">super_gradients.common.factories.losses_factory</span> <span class="kn">import</span> <span class="n">LossesFactory</span>
 <span class="kn">from</span> <span class="nn">super_gradients.common.factories.metrics_factory</span> <span class="kn">import</span> <span class="n">MetricsFactory</span>
 <span class="kn">from</span> <span class="nn">super_gradients.common.factories.metrics_factory</span> <span class="kn">import</span> <span class="n">MetricsFactory</span>
 <span class="kn">from</span> <span class="nn">super_gradients.common.sg_loggers</span> <span class="kn">import</span> <span class="n">SG_LOGGERS</span>
 <span class="kn">from</span> <span class="nn">super_gradients.common.sg_loggers</span> <span class="kn">import</span> <span class="n">SG_LOGGERS</span>
 <span class="kn">from</span> <span class="nn">super_gradients.common.sg_loggers.abstract_sg_logger</span> <span class="kn">import</span> <span class="n">AbstractSGLogger</span>
 <span class="kn">from</span> <span class="nn">super_gradients.common.sg_loggers.abstract_sg_logger</span> <span class="kn">import</span> <span class="n">AbstractSGLogger</span>
 <span class="kn">from</span> <span class="nn">super_gradients.common.sg_loggers.base_sg_logger</span> <span class="kn">import</span> <span class="n">BaseSGLogger</span>
 <span class="kn">from</span> <span class="nn">super_gradients.common.sg_loggers.base_sg_logger</span> <span class="kn">import</span> <span class="n">BaseSGLogger</span>
-<span class="kn">from</span> <span class="nn">super_gradients.training</span> <span class="kn">import</span> <span class="n">utils</span> <span class="k">as</span> <span class="n">core_utils</span>
+<span class="kn">from</span> <span class="nn">super_gradients.training</span> <span class="kn">import</span> <span class="n">utils</span> <span class="k">as</span> <span class="n">core_utils</span><span class="p">,</span> <span class="n">models</span><span class="p">,</span> <span class="n">dataloaders</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.models</span> <span class="kn">import</span> <span class="n">SgModule</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.models</span> <span class="kn">import</span> <span class="n">SgModule</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.pretrained_models</span> <span class="kn">import</span> <span class="n">PRETRAINED_NUM_CLASSES</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.pretrained_models</span> <span class="kn">import</span> <span class="n">PRETRAINED_NUM_CLASSES</span>
-<span class="kn">from</span> <span class="nn">super_gradients.training.utils</span> <span class="kn">import</span> <span class="n">sg_model_utils</span>
-<span class="kn">from</span> <span class="nn">super_gradients.training.utils.quantization_utils</span> <span class="kn">import</span> <span class="n">QATCallback</span>
-<span class="kn">from</span> <span class="nn">super_gradients.training.utils.sg_model_utils</span> <span class="kn">import</span> <span class="n">MonitoredValue</span>
-<span class="kn">from</span> <span class="nn">super_gradients.training</span> <span class="kn">import</span> <span class="n">metrics</span>
-<span class="kn">from</span> <span class="nn">super_gradients.training.exceptions.sg_model_exceptions</span> <span class="kn">import</span> <span class="n">UnsupportedOptimizerFormat</span><span class="p">,</span> \
-    <span class="n">IllegalDataloaderInitialization</span>
-<span class="kn">from</span> <span class="nn">super_gradients.training.datasets</span> <span class="kn">import</span> <span class="n">DatasetInterface</span>
+<span class="kn">from</span> <span class="nn">super_gradients.training.utils</span> <span class="kn">import</span> <span class="n">sg_trainer_utils</span>
+<span class="kn">from</span> <span class="nn">super_gradients.training.utils.sg_trainer_utils</span> <span class="kn">import</span> <span class="n">MonitoredValue</span><span class="p">,</span> <span class="n">parse_args</span><span class="p">,</span> <span class="n">log_main_training_params</span>
+<span class="kn">from</span> <span class="nn">super_gradients.training.exceptions.sg_trainer_exceptions</span> <span class="kn">import</span> <span class="n">UnsupportedOptimizerFormat</span><span class="p">,</span> <span class="n">GPUModeNotSetupError</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.losses</span> <span class="kn">import</span> <span class="n">LOSSES</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.losses</span> <span class="kn">import</span> <span class="n">LOSSES</span>
-<span class="kn">from</span> <span class="nn">super_gradients.training.metrics.metric_utils</span> <span class="kn">import</span> <span class="n">get_metrics_titles</span><span class="p">,</span> <span class="n">get_metrics_results_tuple</span><span class="p">,</span> \
-    <span class="n">get_logging_values</span><span class="p">,</span> \
-    <span class="n">get_metrics_dict</span><span class="p">,</span> <span class="n">get_train_loop_description_dict</span>
+<span class="kn">from</span> <span class="nn">super_gradients.training.metrics.metric_utils</span> <span class="kn">import</span> <span class="p">(</span>
+    <span class="n">get_metrics_titles</span><span class="p">,</span>
+    <span class="n">get_metrics_results_tuple</span><span class="p">,</span>
+    <span class="n">get_logging_values</span><span class="p">,</span>
+    <span class="n">get_metrics_dict</span><span class="p">,</span>
+    <span class="n">get_train_loop_description_dict</span><span class="p">,</span>
+<span class="p">)</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.params</span> <span class="kn">import</span> <span class="n">TrainingParams</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.params</span> <span class="kn">import</span> <span class="n">TrainingParams</span>
-<span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">DetectionPostPredictionCallback</span>
-<span class="kn">from</span> <span class="nn">super_gradients.training.utils.distributed_training_utils</span> <span class="kn">import</span> <span class="n">MultiGPUModeAutocastWrapper</span><span class="p">,</span> \
-    <span class="n">reduce_results_tuple_for_ddp</span><span class="p">,</span> <span class="n">compute_precise_bn_stats</span>
+<span class="kn">from</span> <span class="nn">super_gradients.training.utils.distributed_training_utils</span> <span class="kn">import</span> <span class="p">(</span>
+    <span class="n">MultiGPUModeAutocastWrapper</span><span class="p">,</span>
+    <span class="n">reduce_results_tuple_for_ddp</span><span class="p">,</span>
+    <span class="n">compute_precise_bn_stats</span><span class="p">,</span>
+    <span class="n">setup_device</span><span class="p">,</span>
+    <span class="n">require_gpu_setup</span><span class="p">,</span>
+    <span class="n">get_gpu_mem_utilization</span><span class="p">,</span>
+    <span class="n">get_world_size</span><span class="p">,</span>
+    <span class="n">get_local_rank</span><span class="p">,</span>
+    <span class="n">wait_for_the_master</span><span class="p">,</span>
+<span class="p">)</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.utils.ema</span> <span class="kn">import</span> <span class="n">ModelEMA</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.utils.ema</span> <span class="kn">import</span> <span class="n">ModelEMA</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.utils.optimizer_utils</span> <span class="kn">import</span> <span class="n">build_optimizer</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.utils.optimizer_utils</span> <span class="kn">import</span> <span class="n">build_optimizer</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.utils.weight_averaging_utils</span> <span class="kn">import</span> <span class="n">ModelWeightAveraging</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.utils.weight_averaging_utils</span> <span class="kn">import</span> <span class="n">ModelWeightAveraging</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.metrics</span> <span class="kn">import</span> <span class="n">Accuracy</span><span class="p">,</span> <span class="n">Top5</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.metrics</span> <span class="kn">import</span> <span class="n">Accuracy</span><span class="p">,</span> <span class="n">Top5</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.utils</span> <span class="kn">import</span> <span class="n">random_seed</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.utils</span> <span class="kn">import</span> <span class="n">random_seed</span>
-<span class="kn">from</span> <span class="nn">super_gradients.training.utils.checkpoint_utils</span> <span class="kn">import</span> <span class="n">get_ckpt_local_path</span><span class="p">,</span> <span class="n">read_ckpt_state_dict</span><span class="p">,</span> \
-    <span class="n">load_checkpoint_to_model</span><span class="p">,</span> <span class="n">load_pretrained_weights</span>
+<span class="kn">from</span> <span class="nn">super_gradients.training.utils.checkpoint_utils</span> <span class="kn">import</span> <span class="p">(</span>
+    <span class="n">get_ckpt_local_path</span><span class="p">,</span>
+    <span class="n">read_ckpt_state_dict</span><span class="p">,</span>
+    <span class="n">load_checkpoint_to_model</span><span class="p">,</span>
+    <span class="n">load_pretrained_weights</span><span class="p">,</span>
+    <span class="n">get_checkpoints_dir_path</span><span class="p">,</span>
+<span class="p">)</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.datasets_utils</span> <span class="kn">import</span> <span class="n">DatasetStatisticsTensorboardLogger</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.datasets.datasets_utils</span> <span class="kn">import</span> <span class="n">DatasetStatisticsTensorboardLogger</span>
-<span class="kn">from</span> <span class="nn">super_gradients.training.utils.callbacks</span> <span class="kn">import</span> <span class="n">CallbackHandler</span><span class="p">,</span> <span class="n">Phase</span><span class="p">,</span> <span class="n">LR_SCHEDULERS_CLS_DICT</span><span class="p">,</span> <span class="n">PhaseContext</span><span class="p">,</span> \
-    <span class="n">MetricsUpdateCallback</span><span class="p">,</span> <span class="n">LR_WARMUP_CLS_DICT</span><span class="p">,</span> <span class="n">ContextSgMethods</span><span class="p">,</span> <span class="n">LRCallbackBase</span>
+<span class="kn">from</span> <span class="nn">super_gradients.training.utils.callbacks</span> <span class="kn">import</span> <span class="p">(</span>
+    <span class="n">CallbackHandler</span><span class="p">,</span>
+    <span class="n">Phase</span><span class="p">,</span>
+    <span class="n">LR_SCHEDULERS_CLS_DICT</span><span class="p">,</span>
+    <span class="n">PhaseContext</span><span class="p">,</span>
+    <span class="n">MetricsUpdateCallback</span><span class="p">,</span>
+    <span class="n">LR_WARMUP_CLS_DICT</span><span class="p">,</span>
+    <span class="n">ContextSgMethods</span><span class="p">,</span>
+    <span class="n">LRCallbackBase</span><span class="p">,</span>
+<span class="p">)</span>
 <span class="kn">from</span> <span class="nn">super_gradients.common.environment</span> <span class="kn">import</span> <span class="n">environment_config</span>
 <span class="kn">from</span> <span class="nn">super_gradients.common.environment</span> <span class="kn">import</span> <span class="n">environment_config</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.utils</span> <span class="kn">import</span> <span class="n">HpmStruct</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.utils</span> <span class="kn">import</span> <span class="n">HpmStruct</span>
-<span class="kn">from</span> <span class="nn">super_gradients.training.datasets.samplers.infinite_sampler</span> <span class="kn">import</span> <span class="n">InfiniteSampler</span>
+<span class="kn">from</span> <span class="nn">super_gradients.training.utils.hydra_utils</span> <span class="kn">import</span> <span class="n">load_experiment_cfg</span><span class="p">,</span> <span class="n">add_params_to_cfg</span>
+<span class="kn">from</span> <span class="nn">omegaconf</span> <span class="kn">import</span> <span class="n">OmegaConf</span>
 
 
-<span class="kn">from</span> <span class="nn">super_gradients.common</span> <span class="kn">import</span> <span class="n">StrictLoad</span><span class="p">,</span> <span class="n">MultiGPUMode</span><span class="p">,</span> <span class="n">EvaluationType</span>
 <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
 <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
 
 
 
 
-<div class="viewcode-block" id="SgModel"><a class="viewcode-back" href="../../../../super_gradients.training.sg_model.html#super_gradients.training.SgModel">[docs]</a><span class="k">class</span> <span class="nc">SgModel</span><span class="p">:</span>
+<div class="viewcode-block" id="Trainer"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.Trainer">[docs]</a><span class="k">class</span> <span class="nc">Trainer</span><span class="p">:</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    SuperGradient Model - Base Class for Sg Models</span>
 <span class="sd">    SuperGradient Model - Base Class for Sg Models</span>
 
 
@@ -168,30 +195,17 @@
 <span class="sd">        returns the test loss, accuracy and runtime</span>
 <span class="sd">        returns the test loss, accuracy and runtime</span>
 <span class="sd">    &quot;&quot;&quot;</span>
 <span class="sd">    &quot;&quot;&quot;</span>
 
 
-    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">experiment_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">device</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">multi_gpu</span><span class="p">:</span> <span class="n">Union</span
-                 <span class="n">model_checkpoints_location</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39;local&#39;</span><span class="p">,</span>
-                 <span class="n">overwrite_local_checkpoint</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> <span class="n">ckpt_name</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39;ckpt_latest.pth&#39;</span><span class="p">,</span>
-                 <span class="n">post_prediction_callback</span><span class="p">:</span> <span class="n">DetectionPostPredictionCallback</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">ckpt_root_dir</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
-                 <span class="n">train_loader</span><span class="p">:</span> <span class="n">DataLoader</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">valid_loader</span><span class="p">:</span> <span class="n">DataLoader</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">test_loader</span><span class="p">:</span> <span class="n">DataLoader</span> <span class="o">=</span> <span class="kc">Non
-                 <span class="n">classes</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Any</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
+    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">experiment_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">device</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">multi_gpu</span><span class="p">:</span> <span class="n">Union</span
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
 
 
 <span class="sd">        :param experiment_name:                      Used for logging and loading purposes</span>
 <span class="sd">        :param experiment_name:                      Used for logging and loading purposes</span>
 <span class="sd">        :param device:                          If equal to &#39;cpu&#39; runs on the CPU otherwise on GPU</span>
 <span class="sd">        :param device:                          If equal to &#39;cpu&#39; runs on the CPU otherwise on GPU</span>
 <span class="sd">        :param multi_gpu:                       If True, runs on all available devices</span>
 <span class="sd">        :param multi_gpu:                       If True, runs on all available devices</span>
-<span class="sd">        :param model_checkpoints_location:      If set to &#39;s3&#39; saves the Checkpoints in AWS S3</span>
 <span class="sd">                                                otherwise saves the Checkpoints Locally</span>
 <span class="sd">                                                otherwise saves the Checkpoints Locally</span>
-<span class="sd">        :param overwrite_local_checkpoint:      If set to False keeps the current local checkpoint when importing</span>
 <span class="sd">                                                checkpoint from cloud service, otherwise overwrites the local checkpoints file</span>
 <span class="sd">                                                checkpoint from cloud service, otherwise overwrites the local checkpoints file</span>
-<span class="sd">        :param ckpt_name:                       The Checkpoint to Load</span>
 <span class="sd">        :param ckpt_root_dir:                   Local root directory path where all experiment logging directories will</span>
 <span class="sd">        :param ckpt_root_dir:                   Local root directory path where all experiment logging directories will</span>
 <span class="sd">                                                reside. When none is give, it is assumed that</span>
 <span class="sd">                                                reside. When none is give, it is assumed that</span>
 <span class="sd">                                                pkg_resources.resource_filename(&#39;checkpoints&#39;, &quot;&quot;) exists and will be used.</span>
 <span class="sd">                                                pkg_resources.resource_filename(&#39;checkpoints&#39;, &quot;&quot;) exists and will be used.</span>
-<span class="sd">        :param train_loader:                    Training set Dataloader instead of using DatasetInterface, must pass &quot;valid_loader&quot;</span>
-<span class="sd">                                                and &quot;classes&quot; along with it</span>
-<span class="sd">        :param valid_loader:                    Validation set Dataloader</span>
-<span class="sd">        :param test_loader:                     Test set Dataloader</span>
-<span class="sd">        :param classes:                         List of class labels</span>
 
 
 <span class="sd">        &quot;&quot;&quot;</span>
 <span class="sd">        &quot;&quot;&quot;</span>
         <span class="c1"># SET THE EMPTY PROPERTIES</span>
         <span class="c1"># SET THE EMPTY PROPERTIES</span>
@@ -201,7 +215,6 @@
         <span class="bp">self</span><span class="o">.</span><span class="n">ema_model</span> <span class="o">=</span> <span class="kc">None</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">ema_model</span> <span class="o">=</span> <span class="kc">None</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span> <span class="o">=</span> <span class="kc">None</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span> <span class="o">=</span> <span class="kc">None</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">update_param_groups</span> <span class="o">=</span> <span class="kc">None</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">update_param_groups</span> <span class="o">=</span> <span class="kc">None</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">post_prediction_callback</span> <span class="o">=</span> <span class="kc">None</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">criterion</span> <span class="o">=</span> <span class="kc">None</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">criterion</span> <span class="o">=</span> <span class="kc">None</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span> <span class="o">=</span> <span class="kc">None</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span> <span class="o">=</span> <span class="kc">None</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">scaler</span> <span class="o">=</span> <span class="kc">None</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">scaler</span> <span class="o">=</span> <span class="kc">None</span>
@@ -217,157 +230,189 @@
         <span class="bp">self</span><span class="o">.</span><span class="n">ddp_silent_mode</span> <span class="o">=</span> <span class="kc">False</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">ddp_silent_mode</span> <span class="o">=</span> <span class="kc">False</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">source_ckpt_folder_name</span> <span class="o">=</span> <span class="kc">None</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">source_ckpt_folder_name</span> <span class="o">=</span> <span class="kc">None</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">model_weight_averaging</span> <span class="o">=</span> <span class="kc">None</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">model_weight_averaging</span> <span class="o">=</span> <span class="kc">None</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">average_model_checkpoint_filename</span> <span class="o">=</span> <span class="s1">&#39;average_model.pth&#39;</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">average_model_checkpoint_filename</span> <span class="o">=</span> <span class="s2">&quot;average_model.pth&quot;</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">start_epoch</span> <span class="o">=</span> <span class="mi">0</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">start_epoch</span> <span class="o">=</span> <span class="mi">0</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">best_metric</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">inf</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">best_metric</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">inf</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">external_checkpoint_path</span> <span class="o">=</span> <span class="kc">None</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">external_checkpoint_path</span> <span class="o">=</span> <span class="kc">None</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">strict_load</span> <span class="o">=</span> <span class="n">StrictLoad</span><span class="o">.</span><span class="n">ON</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">strict_load</span> <span class="o">=</span> <span class="n">StrictLoad</span><span class="o">.</span><span class="n">ON</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">load_ema_as_net</span> <span class="o">=</span> <span class="kc">False</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">load_ema_as_net</span> <span class="o">=</span> <span class="kc">False</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">ckpt_best_name</span> <span class="o">=</span> <span class="s1">&#39;ckpt_best.pth&#39;</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">ckpt_best_name</span> <span class="o">=</span> <span class="s2">&quot;ckpt_best.pth&quot;</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">enable_qat</span> <span class="o">=</span> <span class="kc">False</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">enable_qat</span> <span class="o">=</span> <span class="kc">False</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">qat_params</span> <span class="o">=</span> <span class="p">{}</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">qat_params</span> <span class="o">=</span> <span class="p">{}</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">_infinite_train_loader</span> <span class="o">=</span> <span class="kc">False</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">_infinite_train_loader</span> <span class="o">=</span> <span class="kc">False</span>
-
-        <span class="c1"># DETERMINE THE LOCATION OF THE LOSS AND ACCURACY IN THE RESULTS TUPLE OUTPUTED BY THE TEST</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">loss_idx_in_results_tuple</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">acc_idx_in_results_tuple</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_first_backward</span> <span class="o">=</span> <span class="kc">True</span>
 
 
         <span class="c1"># METRICS</span>
         <span class="c1"># METRICS</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">loss_logging_items_names</span> <span class="o">=</span> <span class="kc">None</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">loss_logging_items_names</span> <span class="o">=</span> <span class="kc">None</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">train_metrics</span> <span class="o">=</span> <span class="kc">None</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">train_metrics</span> <span class="o">=</span> <span class="kc">None</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">valid_metrics</span> <span class="o">=</span> <span class="kc">None</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">valid_metrics</span> <span class="o">=</span> <span class="kc">None</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">greater_metric_to_watch_is_better</span> <span class="o">=</span> <span class="kc">None</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">greater_metric_to_watch_is_better</span> <span class="o">=</span> <span class="kc">None</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">metric_to_watch</span> <span class="o">=</span> <span class="kc">None</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">greater_train_metrics_is_better</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">bool</span><span class="p">]</span> <span class="o">=</span> <span class="p">{}</span>  <span class="c1"># For each metric, indicates if greater is better</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">greater_valid_metrics_is_better</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">bool</span><span class="p">]</span> <span class="o">=</span> <span class="p">{}</span>
 
 
         <span class="c1"># SETTING THE PROPERTIES FROM THE CONSTRUCTOR</span>
         <span class="c1"># SETTING THE PROPERTIES FROM THE CONSTRUCTOR</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">experiment_name</span> <span class="o">=</span> <span class="n">experiment_name</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">experiment_name</span> <span class="o">=</span> <span class="n">experiment_name</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">ckpt_name</span> <span class="o">=</span> <span class="n">ckpt_name</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">overwrite_local_checkpoint</span> <span class="o">=</span> <span class="n">overwrite_local_checkpoint</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">model_checkpoints_location</span> <span class="o">=</span> <span class="n">model_checkpoints_location</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">_set_dataset_properties</span><span class="p">(</span><span class="n">classes</span><span class="p">,</span> <span class="n">test_loader</span><span class="p">,</span> <span class="n">train_loader</span><span class="p">,</span> <span class="n">valid_loader</span><span class="p">)</span>
-
-        <span class="c1"># CREATING THE LOGGING DIR BASED ON THE INPUT PARAMS TO PREVENT OVERWRITE OF LOCAL VERSION</span>
-        <span class="k">if</span> <span class="n">ckpt_root_dir</span><span class="p">:</span>
-            <span class="bp">self</span><span class="o">.</span><span class="n">checkpoints_dir_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">ckpt_root_dir</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">experiment_name</span><span class="p">)</span>
-        <span class="k">elif</span> <span class="n">pkg_resources</span><span class="o">.</span><span class="n">resource_exists</span><span class="p">(</span><span class="s2">&quot;checkpoints&quot;</span><span class="p">,</span> <span class="s2">&quot;&quot;</span><span class="p">):</span>
-            <span class="bp">self</span><span class="o">.</span><span class="n">checkpoints_dir_path</span> <span class="o">=</span> <span class="n">pkg_resources</span><span class="o">.</span><span class="n">resource_filename</span><span class="p">(</span><span class="s1">&#39;checkpoints&#39;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">experiment_name</span><span class="p">)</span>
-        <span class="k">else</span><span class="p">:</span>
-            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Illegal checkpoints directory: pass ckpt_root_dir that exists, or add &#39;checkpoints&#39; to&quot;</span>
-                             <span class="s2">&quot;resources.&quot;</span><span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">ckpt_name</span> <span class="o">=</span> <span class="kc">None</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">checkpoints_dir_path</span> <span class="o">=</span> <span class="n">get_checkpoints_dir_path</span><span class="p">(</span><span class="n">experiment_name</span><span class="p">,</span> <span class="n">ckpt_root_dir</span><span class="p">)</span>
 
 
         <span class="c1"># INITIALIZE THE DEVICE FOR THE MODEL</span>
         <span class="c1"># INITIALIZE THE DEVICE FOR THE MODEL</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">_initialize_device</span><span class="p">(</span><span class="n">requested_device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">requested_multi_gpu</span><span class="o">=</span><span class="n">multi_gpu</span><span class="p">)</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">_initialize_device</span><span class="p">(</span><span class="n">requested_device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span> <span class="n">requested_multi_gpu</span><span class="o">=</span><span class="n">multi_gpu</span><span class="p">)</span>
 
 
-        <span class="bp">self</span><span class="o">.</span><span class="n">post_prediction_callback</span> <span class="o">=</span> <span class="n">post_prediction_callback</span>
         <span class="c1"># SET THE DEFAULTS</span>
         <span class="c1"># SET THE DEFAULTS</span>
         <span class="c1"># TODO: SET DEFAULT TRAINING PARAMS FOR EACH TASK</span>
         <span class="c1"># TODO: SET DEFAULT TRAINING PARAMS FOR EACH TASK</span>
 
 
-        <span class="n">default_results_titles</span> <span class="o">=</span> <span class="p">[</span><span class="s1">&#39;Train Loss&#39;</span><span class="p">,</span> <span class="s1">&#39;Train Acc&#39;</span><span class="p">,</span> <span class="s1">&#39;Train Top5&#39;</span><span class="p">,</span> <span class="s1">&#39;Valid Loss&#39;</span><span class="p">,</span> <span class="s1">&#39;Valid Acc&#39;</span><span class="p">,</span> <span class="s1">&#39;Valid Top5&#39;</span><span cla
+        <span class="n">default_results_titles</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;Train Loss&quot;</span><span class="p">,</span> <span class="s2">&quot;Train Acc&quot;</span><span class="p">,</span> <span class="s2">&quot;Train Top5&quot;</span><span class="p">,</span> <span class="s2">&quot;Valid Loss&quot;</span><span class="p">,</span> <span class="s2">&quot;Valid Acc&quot;</span><span class="p">,</span> <span class="s2">&quot;Valid Top5&quot;</sp
 
 
         <span class="bp">self</span><span class="o">.</span><span class="n">results_titles</span> <span class="o">=</span> <span class="n">default_results_titles</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">results_titles</span> <span class="o">=</span> <span class="n">default_results_titles</span>
 
 
-        <span class="bp">self</span><span class="o">.</span><span class="n">loss_idx_in_results_tuple</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">acc_idx_in_results_tuple</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span>
-        <span class="n">default_train_metrics</span><span class="p">,</span> <span class="n">default_valid_metrics</span> <span class="o">=</span> <span class="n">MetricCollection</span><span class="p">([</span><span class="n">Accuracy</span><span class="p">(),</span> <span class="n">Top5</span><span class="p">()]),</span> <span class="n">MetricCollection</span><span class="p">(</span>
-            <span class="p">[</span><span class="n">Accuracy</span><span class="p">(),</span> <span class="n">Top5</span><span class="p">()])</span>
-
-        <span class="n">default_loss_logging_items_names</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;Loss&quot;</span><span class="p">]</span>
+        <span class="n">default_train_metrics</span><span class="p">,</span> <span class="n">default_valid_metrics</span> <span class="o">=</span> <span class="n">MetricCollection</span><span class="p">([</span><span class="n">Accuracy</span><span class="p">(),</span> <span class="n">Top5</span><span class="p">()]),</span> <span class="n">MetricCollection</span><span class="p">([</span><span class="n">Accuracy</span><span class="p">(),</span> <span class="n">Top5</span><span class="p">()])</spa
 
 
         <span class="bp">self</span><span class="o">.</span><span class="n">train_metrics</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">valid_metrics</span> <span class="o">=</span> <span class="n">default_train_metrics</span><span class="p">,</span> <span class="n">default_valid_metrics</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">train_metrics</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">valid_metrics</span> <span class="o">=</span> <span class="n">default_train_metrics</span><span class="p">,</span> <span class="n">default_valid_metrics</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">loss_logging_items_names</span> <span class="o">=</span> <span class="n">default_loss_logging_items_names</span>
 
 
         <span class="bp">self</span><span class="o">.</span><span class="n">train_monitored_values</span> <span class="o">=</span> <span class="p">{}</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">train_monitored_values</span> <span class="o">=</span> <span class="p">{}</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">valid_monitored_values</span> <span class="o">=</span> <span class="p">{}</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">valid_monitored_values</span> <span class="o">=</span> <span class="p">{}</span>
 
 
-    <span class="k">def</span> <span class="nf">_set_dataset_properties</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">classes</span><span class="p">,</span> <span class="n">test_loader</span><span class="p">,</span> <span class="n">train_loader</span><span class="p">,</span> <span class="n">valid_loader</span><span class="p">):</span>
-        <span class="k">if</span> <span class="nb">any</span><span class="p">([</span><span class="n">train_loader</span><span class="p">,</span> <span class="n">valid_loader</span><span class="p">,</span> <span class="n">classes</span><span class="p">])</span> <span class="ow">and</span> <span class="ow">not</span> <span class="nb">all</span><span class="p">([</span><span class="n">train_loader</span><span class="p">,</span> <span class="n">valid_loader</span><span class="p">,</span> <span cla
-            <span class="k">raise</span> <span class="n">IllegalDataloaderInitialization</span><span class="p">()</span>
+<div class="viewcode-block" id="Trainer.train_from_config"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.Trainer.train_from_config">[docs]</a>    <span class="nd">@classmethod</span>
+    <span class="k">def</span> <span class="nf">train_from_config</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">cfg</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">DictConfig</span><span class="p">,</span> <span class="nb">dict</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">nn</span><span class="o">.</span><span 
+        <span class="sd">&quot;&quot;&quot;</span>
+<span class="sd">        Trains according to cfg recipe configuration.</span>
 
 
-        <span class="n">dataset_params</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;batch_size&quot;</span><span class="p">:</span> <span class="n">train_loader</span><span class="o">.</span><span class="n">batch_size</span> <span class="k">if</span> <span class="n">train_loader</span> <span class="k">else</span> <span class="kc">None</span><span class="p">,</span>
-                          <span class="s2">&quot;val_batch_size&quot;</span><span class="p">:</span> <span class="n">valid_loader</span><span class="o">.</span><span class="n">batch_size</span> <span class="k">if</span> <span class="n">valid_loader</span> <span class="k">else</span> <span class="kc">None</span><span class="p">,</span>
-                          <span class="s2">&quot;test_batch_size&quot;</span><span class="p">:</span> <span class="n">test_loader</span><span class="o">.</span><span class="n">batch_size</span> <span class="k">if</span> <span class="n">test_loader</span> <span class="k">else</span> <span class="kc">None</span><span class="p">,</span>
-                          <span class="s2">&quot;dataset_dir&quot;</span><span class="p">:</span> <span class="kc">None</span><span class="p">,</span>
-                          <span class="s2">&quot;s3_link&quot;</span><span class="p">:</span> <span class="kc">None</span><span class="p">}</span>
+<span class="sd">        @param cfg: The parsed DictConfig from yaml recipe files or a dictionary</span>
+<span class="sd">        @return: the model and the output of trainer.train(...) (i.e results tuple)</span>
+<span class="sd">        &quot;&quot;&quot;</span>
 
 
-        <span class="k">if</span> <span class="n">train_loader</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">multi_gpu</span> <span class="o">==</span> <span class="n">MultiGPUMode</span><span class="o">.</span><span class="n">DISTRIBUTED_DATA_PARALLEL</span><span class="p">:</span>
-            <span class="k">if</span> <span class="ow">not</span> <span class="nb">all</span><span class="p">([</span><span class="nb">isinstance</span><span class="p">(</span><span class="n">train_loader</span><span class="o">.</span><span class="n">sampler</span><span class="p">,</span> <span class="n">DistributedSampler</span><span class="p">),</span>
-                        <span class="nb">isinstance</span><span class="p">(</span><span class="n">valid_loader</span><span class="o">.</span><span class="n">sampler</span><span class="p">,</span> <span class="n">DistributedSampler</span><span class="p">),</span>
-                        <span class="n">test_loader</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">test_loader</span><span class="o">.</span><span class="n">sampler</span><span class="p">,</span> <span class="n">DistributedSampler</span><span class="p">)]):</span>
-                <span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s2">&quot;DDP training was selected but the dataloader samplers are not of type DistributedSamplers&quot;</span><span class="p">)</span>
+        <span class="n">setup_device</span><span class="p">(</span><span class="n">multi_gpu</span><span class="o">=</span><span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="n">cfg</span><span class="p">,</span> <span class="s2">&quot;multi_gpu&quot;</span><span class="p">,</span> <span class="n">MultiGPUMode</span><span class="o">.</span><span class="n">OFF</span><span class="p">),</span> <span class="n">num_gpus</span>
+
+        <span class="c1"># INSTANTIATE ALL OBJECTS IN CFG</span>
+        <span class="n">cfg</span> <span class="o">=</span> <span class="n">hydra</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">instantiate</span><span class="p">(</span><span class="n">cfg</span><span class="p">)</span>
+
+        <span class="n">kwargs</span> <span class="o">=</span> <span class="n">parse_args</span><span class="p">(</span><span class="n">cfg</span><span class="p">,</span> <span class="bp">cls</span><span class="o">.</span><span class="fm">__init__</span><span class="p">)</span>
+
+        <span class="n">trainer</span> <span class="o">=</span> <span class="n">Trainer</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
+
+        <span class="c1"># INSTANTIATE DATA LOADERS</span>
+        <span class="n">train_dataloader</span> <span class="o">=</span> <span class="n">dataloaders</span><span class="o">.</span><span class="n">get</span><span class="p">(</span>
+            <span class="n">name</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">train_dataloader</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">train_dataset_params</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="n">cf
+        <span class="p">)</span>
+
+        <span class="n">val_dataloader</span> <span class="o">=</span> <span class="n">dataloaders</span><span class="o">.</span><span class="n">get</span><span class="p">(</span>
+            <span class="n">name</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">val_dataloader</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">val_dataset_params</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="n">cfg</s
+        <span class="p">)</span>
+
+        <span class="c1"># BUILD NETWORK</span>
+        <span class="n">model</span> <span class="o">=</span> <span class="n">models</span><span class="o">.</span><span class="n">get</span><span class="p">(</span>
+            <span class="n">model_name</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">architecture</span><span class="p">,</span>
+            <span class="n">num_classes</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">arch_params</span><span class="o">.</span><span class="n">num_classes</span><span class="p">,</span>
+            <span class="n">arch_params</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">arch_params</span><span class="p">,</span>
+            <span class="n">strict_load</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">checkpoint_params</span><span class="o">.</span><span class="n">strict_load</span><span class="p">,</span>
+            <span class="n">pretrained_weights</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">checkpoint_params</span><span class="o">.</span><span class="n">pretrained_weights</span><span class="p">,</span>
+            <span class="n">checkpoint_path</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">checkpoint_params</span><span class="o">.</span><span class="n">checkpoint_path</span><span class="p">,</span>
+            <span class="n">load_backbone</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">checkpoint_params</span><span class="o">.</span><span class="n">load_backbone</span><span class="p">,</span>
+        <span class="p">)</span>
+        <span class="n">recipe_logged_cfg</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;recipe_config&quot;</span><span class="p">:</span> <span class="n">OmegaConf</span><span class="o">.</span><span class="n">to_container</span><span class="p">(</span><span class="n">cfg</span><span class="p">,</span> <span class="n">resolve</span><span class="o">=</span><span class="kc">True</span><span class="p">)}</span>
+        <span class="c1"># TRAIN</span>
+        <span class="n">res</span> <span class="o">=</span> <span class="n">trainer</span><span class="o">.</span><span class="n">train</span><span class="p">(</span>
+            <span class="n">model</span><span class="o">=</span><span class="n">model</span><span class="p">,</span>
+            <span class="n">train_loader</span><span class="o">=</span><span class="n">train_dataloader</span><span class="p">,</span>
+            <span class="n">valid_loader</span><span class="o">=</span><span class="n">val_dataloader</span><span class="p">,</span>
+            <span class="n">training_params</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">training_hyperparams</span><span class="p">,</span>
+            <span class="n">additional_configs_to_log</span><span class="o">=</span><span class="n">recipe_logged_cfg</span><span class="p">,</span>
+        <span class="p">)</span>
+
+        <span class="k">return</span> <span class="n">model</span><span class="p">,</span> <span class="n">res</span></div>
+
+<div class="viewcode-block" id="Trainer.resume_experiment"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.Trainer.resume_experiment">[docs]</a>    <span class="nd">@classmethod</span>
+    <span class="k">def</span> <span class="nf">resume_experiment</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">experiment_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">ckpt_root_dir</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class=
+        <span class="sd">&quot;&quot;&quot;</span>
+<span class="sd">        Resume a training that was run using our recipes.</span>
 
 
-        <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">valid_loader</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">test_loader</span><span class="p">,</span> <span class="bp">self</span><span class="o
-            <span class="n">HpmStruct</span><span class="p">(</span><span class="o">**</span><span class="n">dataset_params</span><span class="p">),</span> <span class="n">train_loader</span><span class="p">,</span> <span class="n">valid_loader</span><span class="p">,</span> <span class="n">test_loader</span><span class="p">,</span> <span class="n">classes</span>
+<span class="sd">        :param experiment_name:     Name of the experiment to resume</span>
+<span class="sd">        :param ckpt_root_dir:       Directory including the checkpoints</span>
+<span class="sd">        &quot;&quot;&quot;</span>
+        <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">&quot;Resume training using the checkpoint recipe, ignoring the current recipe&quot;</span><span class="p">)</span>
+        <span class="n">cfg</span> <span class="o">=</span> <span class="n">load_experiment_cfg</span><span class="p">(</span><span class="n">experiment_name</span><span class="p">,</span> <span class="n">ckpt_root_dir</span><span class="p">)</span>
+        <span class="n">add_params_to_cfg</span><span class="p">(</span><span class="n">cfg</span><span class="p">,</span> <span class="n">params</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;training_hyperparams.resume=True&quot;</span><span class="p">])</span>
+        <span class="bp">cls</span><span class="o">.</span><span class="n">train_from_config</span><span class="p">(</span><span class="n">cfg</span><span class="p">)</span></div>
 
 
-<div class="viewcode-block" id="SgModel.connect_dataset_interface"><a class="viewcode-back" href="../../../../super_gradients.training.sg_model.html#super_gradients.training.SgModel.connect_dataset_interface">[docs]</a>    <span class="nd">@resolve_param</span><span class="p">(</span><span class="s1">&#39;dataset_interface&#39;</span><span class="p">,</span> <span class="n">DatasetsFactory</span><span class="p">())</span>
-    <span class="k">def</span> <span class="nf">connect_dataset_interface</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_interface</span><span class="p">:</span> <span class="n">DatasetInterface</span><span class="p">,</span> <span class="n">data_loader_num_workers</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">8</span><span class="p">):</span>
+<div class="viewcode-block" id="Trainer.evaluate_from_recipe"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.Trainer.evaluate_from_recipe">[docs]</a>    <span class="nd">@classmethod</span>
+    <span class="k">def</span> <span class="nf">evaluate_from_recipe</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">cfg</span><span class="p">:</span> <span class="n">DictConfig</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
-<span class="sd">        :param dataset_interface: DatasetInterface object</span>
-<span class="sd">        :param data_loader_num_workers: The number of threads to initialize the Data Loaders with</span>
-<span class="sd">            The dataset to be connected</span>
+<span class="sd">        Evaluate according to a cfg recipe configuration.</span>
+
+<span class="sd">        Note:   This script does NOT run training, only validation.</span>
+<span class="sd">                Please make sure that the config refers to a PRETRAINED MODEL either from one of your checkpoint or from pretrained weights from model zoo.</span>
+<span class="sd">        :param cfg: The parsed DictConfig from yaml recipe files or a dictionary</span>
 <span class="sd">        &quot;&quot;&quot;</span>
 <span class="sd">        &quot;&quot;&quot;</span>
-        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span><span class="p">:</span>
-            <span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s2">&quot;Overriding the dataloaders that SgModel was initialized with&quot;</span><span class="p">)</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">dataset_interface</span> <span class="o">=</span> <span class="n">dataset_interface</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">valid_loader</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">test_loader</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> \
-            <span class="bp">self</span><span class="o">.</span><span class="n">dataset_interface</span><span class="o">.</span><span class="n">get_data_loaders</span><span class="p">(</span><span class="n">batch_size_factor</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">num_devices</span><span class="p">,</span>
-                                                    <span class="n">num_workers</span><span class="o">=</span><span class="n">data_loader_num_workers</span><span class="p">,</span>
-                                                    <span class="n">distributed_sampler</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">multi_gpu</span> <span class="o">==</span> <span class="n">MultiGPUMode</span><span class="o">.</span><span class="n">DISTRIBUTED_DATA_PARALLEL</span><span class="p">)</span>
 
 
-        <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_interface</span><span class="o">.</span><span class="n">get_dataset_params</span><span class="p">()</span></div>
+        <span class="n">setup_device</span><span class="p">(</span><span class="n">multi_gpu</span><span class="o">=</span><span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="n">cfg</span><span class="p">,</span> <span class="s2">&quot;multi_gpu&quot;</span><span class="p">,</span> <span class="n">MultiGPUMode</span><span class="o">.</span><span class="n">OFF</span><span class="p">),</span> <span class="n">num_gpus</span>
 
 
-    <span class="c1"># FIXME - we need to resolve flake8&#39;s &#39;function is too complex&#39; for this function</span>
-<div class="viewcode-block" id="SgModel.build_model"><a class="viewcode-back" href="../../../../super_gradients.training.sg_model.html#super_gradients.training.SgModel.build_model">[docs]</a>    <span class="k">def</span> <span class="nf">build_model</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>  <span class="c1"># noqa: C901 - too complex</span>
-                    <span class="n">architecture</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">],</span>
-                    <span class="n">arch_params</span><span class="o">=</span><span class="p">{},</span> <span class="n">checkpoint_params</span><span class="o">=</span><span class="p">{},</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
-        <span class="sd">&quot;&quot;&quot;</span>
-<span class="sd">        :param architecture:               Defines the network&#39;s architecture from models/ALL_ARCHITECTURES</span>
-<span class="sd">        :param arch_params:                Architecture H.P. e.g.: block, num_blocks, num_classes, etc.</span>
-<span class="sd">        :param checkpoint_params:          Dictionary like object with the following key:values:</span>
-
-<span class="sd">            load_checkpoint:            Load a pre-trained checkpoint</span>
-<span class="sd">            strict_load:                See StrictLoad class documentation for details.</span>
-<span class="sd">            source_ckpt_folder_name:    folder name to load the checkpoint from (self.experiment_name if none is given)</span>
-<span class="sd">            load_weights_only:          loads only the weight from the checkpoint and zeroize the training params</span>
-<span class="sd">            load_backbone:              loads the provided checkpoint to self.net.backbone instead of self.net</span>
-<span class="sd">            external_checkpoint_path:   The path to the external checkpoint to be loaded. Can be absolute or relative</span>
-<span class="sd">                                               (ie: path/to/checkpoint.pth). If provided, will automatically attempt to</span>
-<span class="sd">                                               load the checkpoint even if the load_checkpoint flag is not provided.</span>
+        <span class="c1"># INSTANTIATE ALL OBJECTS IN CFG</span>
+        <span class="n">cfg</span> <span class="o">=</span> <span class="n">hydra</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">instantiate</span><span class="p">(</span><span class="n">cfg</span><span class="p">)</span>
 
 
-<span class="sd">        &quot;&quot;&quot;</span>
-        <span class="k">if</span> <span class="s1">&#39;num_classes&#39;</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">arch_params</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
-            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_interface</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
-                <span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="s1">&#39;Error&#39;</span><span class="p">,</span> <span class="s1">&#39;Number of classes not defined in arch params and dataset is not defined&#39;</span><span class="p">)</span>
-            <span class="k">else</span><span class="p">:</span>
-                <span class="n">arch_params</span><span class="p">[</span><span class="s1">&#39;num_classes&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">classes</span><span class="p">)</span>
+        <span class="n">kwargs</span> <span class="o">=</span> <span class="n">parse_args</span><span class="p">(</span><span class="n">cfg</span><span class="p">,</span> <span class="bp">cls</span><span class="o">.</span><span class="fm">__init__</span><span class="p">)</span>
 
 
-        <span class="bp">self</span><span class="o">.</span><span class="n">arch_params</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">HpmStruct</span><span class="p">(</span><span class="o">**</span><span class="n">arch_params</span><span class="p">)</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_params</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">HpmStruct</span><span class="p">(</span><span class="o">**</span><span class="n">checkpoint_params</span><span class="p">)</span>
+        <span class="n">trainer</span> <span class="o">=</span> <span class="n">Trainer</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
 
 
-        <span class="bp">self</span><span class="o">.</span><span class="n">net</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_instantiate_net</span><span class="p">(</span><span class="n">architecture</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">arch_params</span><span class="p">,</span> <span class="n">checkpoint_params</span><span class="p">,</span> <span class="o">*</span><span class=
+        <span class="c1"># INSTANTIATE DATA LOADERS</span>
+        <span class="n">val_dataloader</span> <span class="o">=</span> <span class="n">dataloaders</span><span class="o">.</span><span class="n">get</span><span class="p">(</span>
+            <span class="n">name</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">val_dataloader</span><span class="p">,</span> <span class="n">dataset_params</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">dataset_params</span><span class="o">.</span><span class="n">val_dataset_params</span><span class="p">,</span> <span class="n">dataloader_params</span><span class="o">=</span><span class="n">cfg</s
+        <span class="p">)</span>
 
 
-        <span class="c1"># SAVE THE ARCHITECTURE FOR NEURAL ARCHITECTURE SEARCH</span>
+        <span class="n">checkpoints_dir</span> <span class="o">=</span> <span class="n">Path</span><span class="p">(</span><span class="n">get_checkpoints_dir_path</span><span class="p">(</span><span class="n">experiment_name</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">experiment_name</span><span class="p">,</span> <span class="n">ckpt_root_dir</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">ckpt
+        <span class="n">checkpoint_path</span> <span class="o">=</span> <span class="nb">str</span><span class="p">(</span><span class="n">checkpoints_dir</span> <span class="o">/</span> <span class="n">cfg</span><span class="o">.</span><span class="n">training_hyperparams</span><span class="o">.</span><span class="n">ckpt_name</span><span class="p">)</span>
+        <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Evaluating checkpoint: </span><span class="si">{</span><span class="n">checkpoint_path</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
 
 
-        <span class="bp">self</span><span class="o">.</span><span class="n">architecture</span> <span class="o">=</span> <span class="n">architecture</span>
+        <span class="c1"># BUILD NETWORK</span>
+        <span class="n">model</span> <span class="o">=</span> <span class="n">models</span><span class="o">.</span><span class="n">get</span><span class="p">(</span>
+            <span class="n">model_name</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">architecture</span><span class="p">,</span>
+            <span class="n">num_classes</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">arch_params</span><span class="o">.</span><span class="n">num_classes</span><span class="p">,</span>
+            <span class="n">arch_params</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">arch_params</span><span class="p">,</span>
+            <span class="n">pretrained_weights</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">checkpoint_params</span><span class="o">.</span><span class="n">pretrained_weights</span><span class="p">,</span>
+            <span class="n">checkpoint_path</span><span class="o">=</span><span class="n">checkpoint_path</span><span class="p">,</span>
+            <span class="n">load_backbone</span><span class="o">=</span><span class="n">cfg</span><span class="o">.</span><span class="n">checkpoint_params</span><span class="o">.</span><span class="n">load_backbone</span><span class="p">,</span>
+        <span class="p">)</span>
 
 
-        <span class="bp">self</span><span class="o">.</span><span class="n">_net_to_device</span><span class="p">()</span>
+        <span class="c1"># TEST</span>
+        <span class="n">val_results_tuple</span> <span class="o">=</span> <span class="n">trainer</span><span class="o">.</span><span class="n">test</span><span class="p">(</span><span class="n">model</span><span class="o">=</span><span class="n">model</span><span class="p">,</span> <span class="n">test_loader</span><span class="o">=</span><span class="n">val_dataloader</span><span class="p">,</span> <span class="n">test_metrics_list</span><span class="o">=</span><span class="n">cfg</span><span
 
 
-        <span class="c1"># SET THE FLAG FOR DIFFERENT PARAMETER GROUP OPTIMIZER UPDATE</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">update_param_groups</span> <span class="o">=</span> <span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="o">.</span><span class="n">module</span><span class="p">,</span> <span class="s1">&#39;update_param_groups&#39;</span><span class="p">)</span>
+        <span class="n">valid_metrics_dict</span> <span class="o">=</span> <span class="n">get_metrics_dict</span><span class="p">(</span><span class="n">val_results_tuple</span><span class="p">,</span> <span class="n">trainer</span><span class="o">.</span><span class="n">test_metrics</span><span class="p">,</span> <span class="n">trainer</span><span class="o">.</span><span class="n">loss_logging_items_names</span><span class="p">)</span>
 
 
-        <span class="bp">self</span><span class="o">.</span><span class="n">_load_checkpoint_to_model</span><span class="p">()</span></div>
+        <span class="n">results</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;Validate Results&quot;</span><span class="p">]</span>
+        <span class="n">results</span> <span class="o">+=</span> <span class="p">[</span><span class="sa">f</span><span class="s2">&quot;   - </span><span class="si">{</span><span class="n">metric</span><span class="si">:</span><span class="s2">10</span><span class="si">}</span><span class="s2">: </span><span class="si">{</span><span class="n">value</span><span class="si">}</span><span class="s2">&quot;</span> <span class="k">for</span> <span class="n">metric</span><span class="p">,</span> <spa
+        <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">results</span><span class="p">))</span></div>
 
 
-    <span class="k">def</span> <span class="nf">_set_ckpt_loading_attributes</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
+<div class="viewcode-block" id="Trainer.evaluate_checkpoint"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.Trainer.evaluate_checkpoint">[docs]</a>    <span class="nd">@classmethod</span>
+    <span class="k">def</span> <span class="nf">evaluate_checkpoint</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">experiment_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">ckpt_name</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;ckpt_latest.pth&quot;</span><span class="p">,</span> <span class="n">ckpt_root_dir</span><span class=
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
-<span class="sd">        Sets checkpoint loading related attributes according to self.checkpoint_params</span>
+<span class="sd">        Evaluate a checkpoint resulting from one of your previous experiment, using the same parameters (dataset, valid_metrics,...)</span>
+<span class="sd">        as used during the training of the experiment</span>
+
+<span class="sd">        Note:</span>
+<span class="sd">            The parameters will be unchanged even if the recipe used for that experiment was changed since then.</span>
+<span class="sd">            This is to ensure that validation of the experiment will remain exactly the same as during training.</span>
+
+<span class="sd">        Example, evaluate the checkpoint &quot;average_model.pth&quot; from experiment &quot;my_experiment_name&quot;:</span>
+<span class="sd">            &gt;&gt; evaluate_checkpoint(experiment_name=&quot;my_experiment_name&quot;, ckpt_name=&quot;average_model.pth&quot;)</span>
+
+<span class="sd">        :param experiment_name:     Name of the experiment to validate</span>
+<span class="sd">        :param ckpt_name:           Name of the checkpoint to test (&quot;ckpt_latest.pth&quot;, &quot;average_model.pth&quot; or &quot;ckpt_best.pth&quot; for instance)</span>
+<span class="sd">        :param ckpt_root_dir:       Directory including the checkpoints</span>
 <span class="sd">        &quot;&quot;&quot;</span>
 <span class="sd">        &quot;&quot;&quot;</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">checkpoint</span> <span class="o">=</span> <span class="p">{}</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">strict_load</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_params</span><span class="p">,</span> <span class="s1">&#39;strict_load&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="n">Stric
-        <span class="bp">self</span><span class="o">.</span><span class="n">load_ema_as_net</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_params</span><span class="p">,</span> <span class="s1">&#39;load_ema_as_net&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="
-        <span class="bp">self</span><span class="o">.</span><span class="n">source_ckpt_folder_name</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_params</span><span class="p">,</span> <span class="s1">&#39;source_ckpt_folder_name&#39;</span><span class="p">)</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">load_checkpoint</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_params</span><span class="p">,</span> <span class="s1">&#39;load_checkpoint&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="
-        <span class="bp">self</span><span class="o">.</span><span class="n">load_backbone</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_params</span><span class="p">,</span> <span class="s1">&#39;load_backbone&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="kc">
-        <span class="bp">self</span><span class="o">.</span><span class="n">external_checkpoint_path</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_params</span><span class="p">,</span> <span class="s1">&#39;external_checkpoint_path&#39;</span><span class="p">)</span>
-        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">load_checkpoint</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">external_checkpoint_path</span><span class="p">:</span>
-            <span class="bp">self</span><span class="o">.</span><span class="n">load_weights_only</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_params</span><span class="p">,</span> <span class="s1">&#39;load_weights_only&#39;</span><span class="p">,</span>
-                                                          <span class="n">default_val</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">ckpt_name</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_params</span><span class="p">,</span> <span class="s1">&#39;ckpt_name&#39;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="bp">self</sp
+        <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">&quot;Evaluate checkpoint&quot;</span><span class="p">)</span>
+        <span class="n">cfg</span> <span class="o">=</span> <span class="n">load_experiment_cfg</span><span class="p">(</span><span class="n">experiment_name</span><span class="p">,</span> <span class="n">ckpt_root_dir</span><span class="p">)</span>
+        <span class="n">add_params_to_cfg</span><span class="p">(</span><span class="n">cfg</span><span class="p">,</span> <span class="n">params</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;training_hyperparams.resume=True&quot;</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;ckpt_name=</span><span class="si">{</span><span class="n">ckpt_name</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">])</span>
+        <span class="bp">cls</span><span class="o">.</span><span class="n">evaluate_from_recipe</span><span class="p">(</span><span class="n">cfg</span><span class="p">)</span></div>
+
+    <span class="k">def</span> <span class="nf">_set_dataset_params</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span> <span class="o">=</span> <span class="p">{</span>
+            <span class="s2">&quot;train_dataset_params&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span><span class="o">.</span><span class="n">dataset</span><span class="o">.</span><span class="n">dataset_params</span> <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span><span class="o">.</span><span class
+            <span class="s2">&quot;train_dataloader_params&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span><span class="o">.</span><span class="n">dataloader_params</span> <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span><span class="p">,</span> <span class="s2">&quot;dataloader_params&quot;</span><span
+            <span class="s2">&quot;valid_dataset_params&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">valid_loader</span><span class="o">.</span><span class="n">dataset</span><span class="o">.</span><span class="n">dataset_params</span> <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">valid_loader</span><span class="o">.</span><span class
+            <span class="s2">&quot;valid_dataloader_params&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">valid_loader</span><span class="o">.</span><span class="n">dataloader_params</span> <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">valid_loader</span><span class="p">,</span> <span class="s2">&quot;dataloader_params&quot;</span><span
+        <span class="p">}</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span> <span class="o">=</span> <span class="n">HpmStruct</span><span class="p">(</span><span class="o">**</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_params</span><span class="p">)</span>
 
 
     <span class="k">def</span> <span class="nf">_net_to_device</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
     <span class="k">def</span> <span class="nf">_net_to_device</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
@@ -376,20 +421,17 @@
         <span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
 
 
         <span class="c1"># FOR MULTI-GPU TRAINING (not distributed)</span>
         <span class="c1"># FOR MULTI-GPU TRAINING (not distributed)</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">arch_params</span><span class="o">.</span><span class="n">sync_bn</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">arch_params</span><span class="p">,</span> <span class="s1">&#39;sync_bn&#39;</span><span class="p">,</span> <span class="n">default_val</span><
+        <span class="n">sync_bn</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="p">,</span> <span class="s2">&quot;sync_bn&quot;</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
         <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">multi_gpu</span> <span class="o">==</span> <span class="n">MultiGPUMode</span><span class="o">.</span><span class="n">DATA_PARALLEL</span><span class="p">:</span>
         <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">multi_gpu</span> <span class="o">==</span> <span class="n">MultiGPUMode</span><span class="o">.</span><span class="n">DATA_PARALLEL</span><span class="p">:</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">net</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">DataParallel</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="p">,</span> <span class="n">device_ids</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class
             <span class="bp">self</span><span class="o">.</span><span class="n">net</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">DataParallel</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="p">,</span> <span class="n">device_ids</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class
         <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">multi_gpu</span> <span class="o">==</span> <span class="n">MultiGPUMode</span><span class="o">.</span><span class="n">DISTRIBUTED_DATA_PARALLEL</span><span class="p">:</span>
         <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">multi_gpu</span> <span class="o">==</span> <span class="n">MultiGPUMode</span><span class="o">.</span><span class="n">DISTRIBUTED_DATA_PARALLEL</span><span class="p">:</span>
-            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">arch_params</span><span class="o">.</span><span class="n">sync_bn</span><span class="p">:</span>
+            <span class="k">if</span> <span class="n">sync_bn</span><span class="p">:</span>
                 <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">ddp_silent_mode</span><span class="p">:</span>
                 <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">ddp_silent_mode</span><span class="p">:</span>
-                    <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s1">&#39;DDP - Using Sync Batch Norm... Training time will be affected accordingly&#39;</span><span class="p">)</span>
+                    <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">&quot;DDP - Using Sync Batch Norm... Training time will be affected accordingly&quot;</span><span class="p">)</span>
                 <span class="bp">self</span><span class="o">.</span><span class="n">net</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">SyncBatchNorm</span><span class="o">.</span><span class="n">convert_sync_batchnorm</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="p">)</span><span class="o">.</span><span class="n">to</spa
                 <span class="bp">self</span><span class="o">.</span><span class="n">net</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">SyncBatchNorm</span><span class="o">.</span><span class="n">convert_sync_batchnorm</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="p">)</span><span class="o">.</span><span class="n">to</spa
 
 
-            <span class="n">local_rank</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s1">&#39;:&#39;</span><span class="p">)[</span><span class="mi">1</span><span class="p">])</span>
-            <span class="bp">self</span><span class="o">.</span><span class="n">net</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">parallel</span><span class="o">.</span><span class="n">DistributedDataParallel</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="p">,</span>
-                                                                 <span class="n">device_ids</span><span class="o">=</span><span class="p">[</span><span class="n">local_rank</span><span class="p">],</span>
-                                                                 <span class="n">output_device</span><span class="o">=</span><span class="n">local_rank</span><span class="p">,</span>
-                                                                 <span class="n">find_unused_parameters</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
+            <span class="n">local_rank</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">&quot;:&quot;</span><span class="p">)[</span><span class="mi">1</span><span class="p">])</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">net</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">parallel</span><span class="o">.</span><span class="n">DistributedDataParallel</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="p">,</span> <span class="n">device_ids</span><span class="o">=</sp
 
 
         <span class="k">else</span><span class="p">:</span>
         <span class="k">else</span><span class="p">:</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">net</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">WrappedModel</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="p">)</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">net</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">WrappedModel</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="p">)</span>
@@ -404,8 +446,7 @@
         <span class="c1"># SET THE MODEL IN training STATE</span>
         <span class="c1"># SET THE MODEL IN training STATE</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="o">.</span><span class="n">train</span><span class="p">()</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="o">.</span><span class="n">train</span><span class="p">()</span>
         <span class="c1"># THE DISABLE FLAG CONTROLS WHETHER THE PROGRESS BAR IS SILENT OR PRINTS THE LOGS</span>
         <span class="c1"># THE DISABLE FLAG CONTROLS WHETHER THE PROGRESS BAR IS SILENT OR PRINTS THE LOGS</span>
-        <span class="n">progress_bar_train_loader</span> <span class="o">=</span> <span class="n">tqdm</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span><span class="p">,</span> <span class="n">bar_format</span><span class="o">=</span><span class="s2">&quot;</span><span class="si">{l_bar}{bar:10}{r_bar}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">dynamic_ncols</span><span class="o">=</span><span 
-                                         <span class="n">disable</span><span class="o">=</span><span class="n">silent_mode</span><span class="p">)</span>
+        <span class="n">progress_bar_train_loader</span> <span class="o">=</span> <span class="n">tqdm</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span><span class="p">,</span> <span class="n">bar_format</span><span class="o">=</span><span class="s2">&quot;</span><span class="si">{l_bar}{bar:10}{r_bar}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">dynamic_ncols</span><span class="o">=</span><span 
         <span class="n">progress_bar_train_loader</span><span class="o">.</span><span class="n">set_description</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Train epoch </span><span class="si">{</span><span class="n">epoch</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
         <span class="n">progress_bar_train_loader</span><span class="o">.</span><span class="n">set_description</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Train epoch </span><span class="si">{</span><span class="n">epoch</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
 
 
         <span class="c1"># RESET/INIT THE METRIC LOGGERS</span>
         <span class="c1"># RESET/INIT THE METRIC LOGGERS</span>
@@ -414,21 +455,23 @@
         <span class="bp">self</span><span class="o">.</span><span class="n">train_metrics</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">train_metrics</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
         <span class="n">loss_avg_meter</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">AverageMeter</span><span class="p">()</span>
         <span class="n">loss_avg_meter</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">AverageMeter</span><span class="p">()</span>
 
 
-        <span class="n">context</span> <span class="o">=</span> <span class="n">PhaseContext</span><span class="p">(</span><span class="n">epoch</span><span class="o">=</span><span class="n">epoch</span><span class="p">,</span>
-                               <span class="n">optimizer</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="p">,</span>
-                               <span class="n">metrics_compute_fn</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">train_metrics</span><span class="p">,</span>
-                               <span class="n">loss_avg_meter</span><span class="o">=</span><span class="n">loss_avg_meter</span><span class="p">,</span>
-                               <span class="n">criterion</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">criterion</span><span class="p">,</span>
-                               <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span>
-                               <span class="n">lr_warmup_epochs</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_warmup_epochs</span><span class="p">,</span>
-                               <span class="n">sg_logger</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="p">,</span>
-                               <span class="n">train_loader</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span><span class="p">,</span>
-                               <span class="n">context_methods</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_context_methods</span><span class="p">(</span><span class="n">Phase</span><span class="o">.</span><span class="n">TRAIN_BATCH_END</span><span class="p">),</span>
-                               <span class="n">ddp_silent_mode</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">ddp_silent_mode</span><span class="p">)</span>
+        <span class="n">context</span> <span class="o">=</span> <span class="n">PhaseContext</span><span class="p">(</span>
+            <span class="n">epoch</span><span class="o">=</span><span class="n">epoch</span><span class="p">,</span>
+            <span class="n">optimizer</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="p">,</span>
+            <span class="n">metrics_compute_fn</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">train_metrics</span><span class="p">,</span>
+            <span class="n">loss_avg_meter</span><span class="o">=</span><span class="n">loss_avg_meter</span><span class="p">,</span>
+            <span class="n">criterion</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">criterion</span><span class="p">,</span>
+            <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span>
+            <span class="n">lr_warmup_epochs</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_warmup_epochs</span><span class="p">,</span>
+            <span class="n">sg_logger</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="p">,</span>
+            <span class="n">train_loader</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span><span class="p">,</span>
+            <span class="n">context_methods</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">_get_context_methods</span><span class="p">(</span><span class="n">Phase</span><span class="o">.</span><span class="n">TRAIN_BATCH_END</span><span class="p">),</span>
+            <span class="n">ddp_silent_mode</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">ddp_silent_mode</span><span class="p">,</span>
+        <span class="p">)</span>
 
 
         <span class="k">for</span> <span class="n">batch_idx</span><span class="p">,</span> <span class="n">batch_items</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">progress_bar_train_loader</span><span class="p">):</span>
         <span class="k">for</span> <span class="n">batch_idx</span><span class="p">,</span> <span class="n">batch_items</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">progress_bar_train_loader</span><span class="p">):</span>
             <span class="n">batch_items</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">tensor_container_to_device</span><span class="p">(</span><span class="n">batch_items</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">non_blocking</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
             <span class="n">batch_items</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">tensor_container_to_device</span><span class="p">(</span><span class="n">batch_items</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">non_blocking</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
-            <span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">additional_batch_items</span> <span class="o">=</span> <span class="n">sg_model_utils</span><span class="o">.</span><span class="n">unpack_batch_items</span><span class="p">(</span><span class="n">batch_items</span><span class="p">)</span>
+            <span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">additional_batch_items</span> <span class="o">=</span> <span class="n">sg_trainer_utils</span><span class="o">.</span><span class="n">unpack_batch_items</span><span class="p">(</span><span class="n">batch_items</span><span class="p">)</span>
 
 
             <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_prediction_callback</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
             <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_prediction_callback</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
                 <span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_prediction_callback</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">)</span>
                 <span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_prediction_callback</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">)</span>
@@ -440,12 +483,7 @@
                 <span class="c1"># COMPUTE THE LOSS FOR BACK PROP + EXTRA METRICS COMPUTED DURING THE LOSS FORWARD PASS</span>
                 <span class="c1"># COMPUTE THE LOSS FOR BACK PROP + EXTRA METRICS COMPUTED DURING THE LOSS FORWARD PASS</span>
                 <span class="n">loss</span><span class="p">,</span> <span class="n">loss_log_items</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_losses</span><span class="p">(</span><span class="n">outputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">)</span>
                 <span class="n">loss</span><span class="p">,</span> <span class="n">loss_log_items</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_losses</span><span class="p">(</span><span class="n">outputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">)</span>
 
 
-            <span class="n">context</span><span class="o">.</span><span class="n">update_context</span><span class="p">(</span><span class="n">batch_idx</span><span class="o">=</span><span class="n">batch_idx</span><span class="p">,</span>
-                                   <span class="n">inputs</span><span class="o">=</span><span class="n">inputs</span><span class="p">,</span>
-                                   <span class="n">preds</span><span class="o">=</span><span class="n">outputs</span><span class="p">,</span>
-                                   <span class="n">target</span><span class="o">=</span><span class="n">targets</span><span class="p">,</span>
-                                   <span class="n">loss_log_items</span><span class="o">=</span><span class="n">loss_log_items</span><span class="p">,</span>
-                                   <span class="o">**</span><span class="n">additional_batch_items</span><span class="p">)</span>
+            <span class="n">context</span><span class="o">.</span><span class="n">update_context</span><span class="p">(</span><span class="n">batch_idx</span><span class="o">=</span><span class="n">batch_idx</span><span class="p">,</span> <span class="n">inputs</span><span class="o">=</span><span class="n">inputs</span><span class="p">,</span> <span class="n">preds</span><span class="o">=</span><span class="n">outputs</span><span class="p">,</span> <span class="n">target</span><span class="o">
 
 
             <span class="bp">self</span><span class="o">.</span><span class="n">phase_callback_handler</span><span class="p">(</span><span class="n">Phase</span><span class="o">.</span><span class="n">TRAIN_BATCH_END</span><span class="p">,</span> <span class="n">context</span><span class="p">)</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">phase_callback_handler</span><span class="p">(</span><span class="n">Phase</span><span class="o">.</span><span class="n">TRAIN_BATCH_END</span><span class="p">,</span> <span class="n">context</span><span class="p">)</span>
 
 
@@ -457,13 +495,12 @@
 
 
             <span class="c1"># COMPUTE THE RUNNING USER METRICS AND LOSS RUNNING ITEMS. RESULT TUPLE IS THEIR CONCATENATION.</span>
             <span class="c1"># COMPUTE THE RUNNING USER METRICS AND LOSS RUNNING ITEMS. RESULT TUPLE IS THEIR CONCATENATION.</span>
             <span class="n">logging_values</span> <span class="o">=</span> <span class="n">loss_avg_meter</span><span class="o">.</span><span class="n">average</span> <span class="o">+</span> <span class="n">get_metrics_results_tuple</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">train_metrics</span><span class="p">)</span>
             <span class="n">logging_values</span> <span class="o">=</span> <span class="n">loss_avg_meter</span><span class="o">.</span><span class="n">average</span> <span class="o">+</span> <span class="n">get_metrics_results_tuple</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">train_metrics</span><span class="p">)</span>
-            <span class="n">gpu_memory_utilization</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">memory_cached</span><span class="p">()</span> <span class="o">/</span> <span class="mf">1E9</span> <span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">()</span> <span cl
+            <span class="n">gpu_memory_utilization</span> <span class="o">=</span> <span class="n">get_gpu_mem_utilization</span><span class="p">()</span> <span class="o">/</span> <span class="mf">1e9</span> <span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">()</span> <span class="k">else</span> <span class="mi">0</span>
 
 
             <span class="c1"># RENDER METRICS PROGRESS</span>
             <span class="c1"># RENDER METRICS PROGRESS</span>
-            <span class="n">pbar_message_dict</span> <span class="o">=</span> <span class="n">get_train_loop_description_dict</span><span class="p">(</span><span class="n">logging_values</span><span class="p">,</span>
-                                                                <span class="bp">self</span><span class="o">.</span><span class="n">train_metrics</span><span class="p">,</span>
-                                                                <span class="bp">self</span><span class="o">.</span><span class="n">loss_logging_items_names</span><span class="p">,</span>
-                                                                <span class="n">gpu_mem</span><span class="o">=</span><span class="n">gpu_memory_utilization</span><span class="p">)</span>
+            <span class="n">pbar_message_dict</span> <span class="o">=</span> <span class="n">get_train_loop_description_dict</span><span class="p">(</span>
+                <span class="n">logging_values</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_metrics</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_logging_items_names</span><span class="p">,</span> <span class="n">gpu_mem</span><span class="o">=</span><span class="n">gpu_memory_utilization</span>
+            <span class="p">)</span>
 
 
             <span class="n">progress_bar_train_loader</span><span class="o">.</span><span class="n">set_postfix</span><span class="p">(</span><span class="o">**</span><span class="n">pbar_message_dict</span><span class="p">)</span>
             <span class="n">progress_bar_train_loader</span><span class="o">.</span><span class="n">set_postfix</span><span class="p">(</span><span class="o">**</span><span class="n">pbar_message_dict</span><span class="p">)</span>
 
 
@@ -475,8 +512,9 @@
         <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">ddp_silent_mode</span><span class="p">:</span>
         <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">ddp_silent_mode</span><span class="p">:</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">upload</span><span class="p">()</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">upload</span><span class="p">()</span>
 
 
-        <span class="bp">self</span><span class="o">.</span><span class="n">train_monitored_values</span> <span class="o">=</span> <span class="n">sg_model_utils</span><span class="o">.</span><span class="n">update_monitored_values_dict</span><span class="p">(</span>
-            <span class="n">monitored_values_dict</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">train_monitored_values</span><span class="p">,</span> <span class="n">new_values_dict</span><span class="o">=</span><span class="n">pbar_message_dict</span><span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">train_monitored_values</span> <span class="o">=</span> <span class="n">sg_trainer_utils</span><span class="o">.</span><span class="n">update_monitored_values_dict</span><span class="p">(</span>
+            <span class="n">monitored_values_dict</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">train_monitored_values</span><span class="p">,</span> <span class="n">new_values_dict</span><span class="o">=</span><span class="n">pbar_message_dict</span>
+        <span class="p">)</span>
 
 
         <span class="k">return</span> <span class="n">logging_values</span>
         <span class="k">return</span> <span class="n">logging_values</span>
 
 
@@ -489,12 +527,50 @@
         <span class="k">else</span><span class="p">:</span>
         <span class="k">else</span><span class="p">:</span>
             <span class="n">loss_logging_items</span> <span class="o">=</span> <span class="n">loss</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span>
             <span class="n">loss_logging_items</span> <span class="o">=</span> <span class="n">loss</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span>
 
 
+        <span class="c1"># ON FIRST BACKWARD, DERRIVE THE LOGGING TITLES.</span>
+        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_logging_items_names</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">_first_backward</span><span class="p">:</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_init_loss_logging_names</span><span class="p">(</span><span class="n">loss_logging_items</span><span class="p">)</span>
+            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">metric_to_watch</span><span class="p">:</span>
+                <span class="bp">self</span><span class="o">.</span><span class="n">_init_monitored_items</span><span class="p">()</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_first_backward</span> <span class="o">=</span> <span class="kc">False</span>
+
         <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">loss_logging_items</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">loss_logging_items_names</span><span class="p">):</span>
         <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">loss_logging_items</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">loss_logging_items_names</span><span class="p">):</span>
-            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Loss output length must match loss_logging_items_names. Got &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span>
-                <span class="nb">len</span><span class="p">(</span><span class="n">loss_logging_items</span><span class="p">))</span> <span class="o">+</span> <span class="s1">&#39;, and &#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">loss_logging_items_names</span><span class="p">)))</span>
+            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
+                <span class="s2">&quot;Loss output length must match loss_logging_items_names. Got &quot;</span>
+                <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">loss_logging_items</span><span class="p">))</span>
+                <span class="o">+</span> <span class="s2">&quot;, and &quot;</span>
+                <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">loss_logging_items_names</span><span class="p">))</span>
+            <span class="p">)</span>
         <span class="c1"># RETURN AND THE LOSS LOGGING ITEMS COMPUTED DURING LOSS FORWARD PASS</span>
         <span class="c1"># RETURN AND THE LOSS LOGGING ITEMS COMPUTED DURING LOSS FORWARD PASS</span>
         <span class="k">return</span> <span class="n">loss</span><span class="p">,</span> <span class="n">loss_logging_items</span>
         <span class="k">return</span> <span class="n">loss</span><span class="p">,</span> <span class="n">loss_logging_items</span>
 
 
+    <span class="k">def</span> <span class="nf">_init_monitored_items</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">metric_idx_in_results_tuple</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">loss_logging_items_names</span> <span class="o">+</span> <span class="n">get_metrics_titles</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">valid_metrics</span><span class="p">))</span><span class="o">.</span><span class=
+        <span class="c1"># Instantiate the values to monitor (loss/metric)</span>
+        <span class="k">for</span> <span class="n">loss_name</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_logging_items_names</span><span class="p">:</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">train_monitored_values</span><span class="p">[</span><span class="n">loss_name</span><span class="p">]</span> <span class="o">=</span> <span class="n">MonitoredValue</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">loss_name</span><span class="p">,</span> <span class="n">greater_is_better</span><span class="o">=</span><span class="kc">False</span><span class="p">)</spa
+            <span class="bp">self</span><span class="o">.</span><span class="n">valid_monitored_values</span><span class="p">[</span><span class="n">loss_name</span><span class="p">]</span> <span class="o">=</span> <span class="n">MonitoredValue</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">loss_name</span><span class="p">,</span> <span class="n">greater_is_better</span><span class="o">=</span><span class="kc">False</span><span class="p">)</spa
+
+        <span class="k">for</span> <span class="n">metric_name</span> <span class="ow">in</span> <span class="n">get_metrics_titles</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">train_metrics</span><span class="p">):</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">train_monitored_values</span><span class="p">[</span><span class="n">metric_name</span><span class="p">]</span> <span class="o">=</span> <span class="n">MonitoredValue</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">metric_name</span><span class="p">,</span> <span class="n">greater_is_better</span><span class="o">=</span><span class="bp">self</span><span class="o">.</
+
+        <span class="k">for</span> <span class="n">metric_name</span> <span class="ow">in</span> <span class="n">get_metrics_titles</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">valid_metrics</span><span class="p">):</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">valid_monitored_values</span><span class="p">[</span><span class="n">metric_name</span><span class="p">]</span> <span class="o">=</span> <span class="n">MonitoredValue</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">metric_name</span><span class="p">,</span> <span class="n">greater_is_better</span><span class="o">=</span><span class="bp">self</span><span class="o">.</
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">results_titles</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;Train_&quot;</span> <span class="o">+</span> <span class="n">t</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_logging_items_names</span> <span class="o">+</span> <span class="n">get_metrics_titles</span><span class="p
+            <span class="s2">&quot;Valid_&quot;</span> <span class="o">+</span> <span class="n">t</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_logging_items_names</span> <span class="o">+</span> <span class="n">get_metrics_titles</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">valid_metrics</span><span class="p">)</span>
+        <span class="p">]</span>
+
+        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">average_best_models</span><span class="p">:</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">model_weight_averaging</span> <span class="o">=</span> <span class="n">ModelWeightAveraging</span><span class="p">(</span>
+                <span class="bp">self</span><span class="o">.</span><span class="n">checkpoints_dir_path</span><span class="p">,</span>
+                <span class="n">greater_is_better</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">greater_metric_to_watch_is_better</span><span class="p">,</span>
+                <span class="n">source_ckpt_folder_name</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">source_ckpt_folder_name</span><span class="p">,</span>
+                <span class="n">metric_to_watch</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">metric_to_watch</span><span class="p">,</span>
+                <span class="n">metric_idx</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">metric_idx_in_results_tuple</span><span class="p">,</span>
+                <span class="n">load_checkpoint</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">load_checkpoint</span><span class="p">,</span>
+            <span class="p">)</span>
+
     <span class="k">def</span> <span class="nf">_backward_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">loss</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">epoch</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">:</span> <span class="nb">int</span><span 
     <span class="k">def</span> <span class="nf">_backward_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">loss</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">epoch</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">:</span> <span class="nb">int</span><span 
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">        Run backprop on the loss and perform a step</span>
 <span class="sd">        Run backprop on the loss and perform a step</span>
@@ -527,44 +603,43 @@
             <span class="c1"># RUN PHASE CALLBACKS</span>
             <span class="c1"># RUN PHASE CALLBACKS</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">phase_callback_handler</span><span class="p">(</span><span class="n">Phase</span><span class="o">.</span><span class="n">TRAIN_BATCH_STEP</span><span class="p">,</span> <span class="n">context</span><span class="p">)</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">phase_callback_handler</span><span class="p">(</span><span class="n">Phase</span><span class="o">.</span><span class="n">TRAIN_BATCH_STEP</span><span class="p">,</span> <span class="n">context</span><span class="p">)</span>
 
 
-    <span class="k">def</span> <span class="nf">_save_checkpoint</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">optimizer</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">epoch</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">validation_results_tuple</span><span class="p">:</span> <span class
-                         <span class="n">context</span><span class="p">:</span> <span class="n">PhaseContext</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
+    <span class="k">def</span> <span class="nf">_save_checkpoint</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">optimizer</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">epoch</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">validation_results_tuple</span><span class="p">:</span> <span class
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">        Save the current state dict as latest (always), best (if metric was improved), epoch# (if determined in training</span>
 <span class="sd">        Save the current state dict as latest (always), best (if metric was improved), epoch# (if determined in training</span>
 <span class="sd">        params)</span>
 <span class="sd">        params)</span>
 <span class="sd">        &quot;&quot;&quot;</span>
 <span class="sd">        &quot;&quot;&quot;</span>
         <span class="c1"># WHEN THE validation_results_tuple IS NONE WE SIMPLY SAVE THE state_dict AS LATEST AND Return</span>
         <span class="c1"># WHEN THE validation_results_tuple IS NONE WE SIMPLY SAVE THE state_dict AS LATEST AND Return</span>
         <span class="k">if</span> <span class="n">validation_results_tuple</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
         <span class="k">if</span> <span class="n">validation_results_tuple</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
-            <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">add_checkpoint</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="s1">&#39;ckpt_latest_weights_only.pth&#39;</span><span class="p">,</span> <span class="n">state_dict</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;net&#39;</span><span class="p">:</span> <span class="bp">self</span><span clas
-                                          <span class="n">global_step</span><span class="o">=</span><span class="n">epoch</span><span class="p">)</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">add_checkpoint</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="s2">&quot;ckpt_latest_weights_only.pth&quot;</span><span class="p">,</span> <span class="n">state_dict</span><span class="o">=</span><span class="p">{</span><span class="s2">&quot;net&quot;</span><span class="p">:</span> <span class="bp">self</span><span 
             <span class="k">return</span>
             <span class="k">return</span>
 
 
         <span class="c1"># COMPUTE THE CURRENT metric</span>
         <span class="c1"># COMPUTE THE CURRENT metric</span>
         <span class="c1"># IF idx IS A LIST - SUM ALL THE VALUES STORED IN THE LIST&#39;S INDICES</span>
         <span class="c1"># IF idx IS A LIST - SUM ALL THE VALUES STORED IN THE LIST&#39;S INDICES</span>
-        <span class="n">metric</span> <span class="o">=</span> <span class="n">validation_results_tuple</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">metric_idx_in_results_tuple</span><span class="p">]</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span>
-            <span class="bp">self</span><span class="o">.</span><span class="n">metric_idx_in_results_tuple</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span> <span class="k">else</span> \
-            <span class="nb">sum</span><span class="p">([</span><span class="n">validation_results_tuple</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">metric_idx_in_results_tuple</span><span class="p">])</span>
+        <span class="n">metric</span> <span class="o">=</span> <span class="p">(</span>
+            <span class="n">validation_results_tuple</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">metric_idx_in_results_tuple</span><span class="p">]</span>
+            <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">metric_idx_in_results_tuple</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span>
+            <span class="k">else</span> <span class="nb">sum</span><span class="p">([</span><span class="n">validation_results_tuple</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">metric_idx_in_results_tuple</span><span class="p">])</span>
+        <span class="p">)</span>
 
 
         <span class="c1"># BUILD THE state_dict</span>
         <span class="c1"># BUILD THE state_dict</span>
-        <span class="n">state</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;net&#39;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span> <span class="s1">&#39;acc&#39;</span><span class="p">:</span> <span class="n">metric</span><span class="p">,</span> <span class="s1">&#39;epoch&#39;</span><span class="p">:</span> <span cla
+        <span class="n">state</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;net&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span> <span class="s2">&quot;acc&quot;</span><span class="p">:</span> <span class="n">metric</span><span class="p">,</span> <span class="s2">&quot;epoch&quot;</span><span class="p">:</span> <sp
         <span class="k">if</span> <span class="n">optimizer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
         <span class="k">if</span> <span class="n">optimizer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
-            <span class="n">state</span><span class="p">[</span><span class="s1">&#39;optimizer_state_dict&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">state_dict</span><span class="p">()</span>
+            <span class="n">state</span><span class="p">[</span><span class="s2">&quot;optimizer_state_dict&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">state_dict</span><span class="p">()</span>
 
 
         <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">scaler</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
         <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">scaler</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
-            <span class="n">state</span><span class="p">[</span><span class="s1">&#39;scaler_state_dict&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scaler</span><span class="o">.</span><span class="n">state_dict</span><span class="p">()</span>
+            <span class="n">state</span><span class="p">[</span><span class="s2">&quot;scaler_state_dict&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scaler</span><span class="o">.</span><span class="n">state_dict</span><span class="p">()</span>
 
 
         <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ema</span><span class="p">:</span>
         <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ema</span><span class="p">:</span>
-            <span class="n">state</span><span class="p">[</span><span class="s1">&#39;ema_net&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ema_model</span><span class="o">.</span><span class="n">ema</span><span class="o">.</span><span class="n">state_dict</span><span class="p">()</span>
+            <span class="n">state</span><span class="p">[</span><span class="s2">&quot;ema_net&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ema_model</span><span class="o">.</span><span class="n">ema</span><span class="o">.</span><span class="n">state_dict</span><span class="p">()</span>
         <span class="c1"># SAVES CURRENT MODEL AS ckpt_latest</span>
         <span class="c1"># SAVES CURRENT MODEL AS ckpt_latest</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">add_checkpoint</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="s1">&#39;ckpt_latest.pth&#39;</span><span class="p">,</span> <span class="n">state_dict</span><span class="o">=</span><span class="n">state</span><span class="p">,</span> <span class="n">global_step</span><span class="o">=</span><span class="n">epoch</span><s
+        <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">add_checkpoint</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="s2">&quot;ckpt_latest.pth&quot;</span><span class="p">,</span> <span class="n">state_dict</span><span class="o">=</span><span class="n">state</span><span class="p">,</span> <span class="n">global_step</span><span class="o">=</span><span class="n">epoch</span>
 
 
         <span class="c1"># SAVE MODEL AT SPECIFIC EPOCHS DETERMINED BY save_ckpt_epoch_list</span>
         <span class="c1"># SAVE MODEL AT SPECIFIC EPOCHS DETERMINED BY save_ckpt_epoch_list</span>
         <span class="k">if</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">save_ckpt_epoch_list</span><span class="p">:</span>
         <span class="k">if</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">save_ckpt_epoch_list</span><span class="p">:</span>
-            <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">add_checkpoint</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="sa">f</span><span class="s1">&#39;ckpt_epoch_</span><span class="si">{</span><span class="n">epoch</span><span class="si">}</span><span class="s1">.pth&#39;</span><span class="p">,</span> <span class="n">state_dict</span><span class="o">=</span><span clas
+            <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">add_checkpoint</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="sa">f</span><span class="s2">&quot;ckpt_epoch_</span><span class="si">{</span><span class="n">epoch</span><span class="si">}</span><span class="s2">.pth&quot;</span><span class="p">,</span> <span class="n">state_dict</span><span class="o">=</span><span cl
 
 
         <span class="c1"># OVERRIDE THE BEST CHECKPOINT AND best_metric IF metric GOT BETTER THAN THE PREVIOUS BEST</span>
         <span class="c1"># OVERRIDE THE BEST CHECKPOINT AND best_metric IF metric GOT BETTER THAN THE PREVIOUS BEST</span>
-        <span class="k">if</span> <span class="p">(</span><span class="n">metric</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">best_metric</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">greater_metric_to_watch_is_better</span><span class="p">)</span> <span class="ow">or</span> <span class="p">(</span>
-                <span class="n">metric</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">best_metric</span> <span class="ow">and</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">greater_metric_to_watch_is_better</span><span class="p">):</span>
+        <span class="k">if</span> <span class="p">(</span><span class="n">metric</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">best_metric</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">greater_metric_to_watch_is_better</span><span class="p">)</span> <span class="ow">or</span> <span class="p">(</span><span class="n">metric</span> <span class="o">&lt;</span> <span class="bp">self</spa
             <span class="c1"># STORE THE CURRENT metric AS BEST</span>
             <span class="c1"># STORE THE CURRENT metric AS BEST</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">best_metric</span> <span class="o">=</span> <span class="n">metric</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">best_metric</span> <span class="o">=</span> <span class="n">metric</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">_save_best_checkpoint</span><span class="p">(</span><span class="n">epoch</span><span class="p">,</span> <span class="n">state</span><span class="p">)</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">_save_best_checkpoint</span><span class="p">(</span><span class="n">epoch</span><span class="p">,</span> <span class="n">state</span><span class="p">)</span>
@@ -578,16 +653,50 @@
 
 
         <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">average_best_models</span><span class="p">:</span>
         <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">average_best_models</span><span class="p">:</span>
             <span class="n">net_for_averaging</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ema_model</span><span class="o">.</span><span class="n">ema</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ema</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">net</span>
             <span class="n">net_for_averaging</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ema_model</span><span class="o">.</span><span class="n">ema</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ema</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">net</span>
-            <span class="n">averaged_model_sd</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_weight_averaging</span><span class="o">.</span><span class="n">get_average_model</span><span class="p">(</span><span class="n">net_for_averaging</span><span class="p">,</span>
-                                                                              <span class="n">validation_results_tuple</span><span class="o">=</span><span class="n">validation_results_tuple</span><span class="p">)</span>
-            <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">add_checkpoint</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">average_model_checkpoint_filename</span><span class="p">,</span>
-                                          <span class="n">state_dict</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;net&#39;</span><span class="p">:</span> <span class="n">averaged_model_sd</span><span class="p">},</span> <span class="n">global_step</span><span class="o">=</span><span class="n">epoch</span><span class="p">)</span>
+            <span class="n">averaged_model_sd</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_weight_averaging</span><span class="o">.</span><span class="n">get_average_model</span><span class="p">(</span><span class="n">net_for_averaging</span><span class="p">,</span> <span class="n">validation_results_tuple</span><span class="o">=</span><span class="n">validation_results_tuple</span><span class="p">)</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">add_checkpoint</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">average_model_checkpoint_filename</span><span class="p">,</span> <span class="n">state_dict</span><span class="o">=</span><span class="p">{</span><span class="s2">&quot;net&quot;</span><span class="p"
 
 
     <span class="k">def</span> <span class="nf">_save_best_checkpoint</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">epoch</span><span class="p">,</span> <span class="n">state</span><span class="p">):</span>
     <span class="k">def</span> <span class="nf">_save_best_checkpoint</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">epoch</span><span class="p">,</span> <span class="n">state</span><span class="p">):</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">add_checkpoint</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">ckpt_best_name</span><span class="p">,</span> <span class="n">state_dict</span><span class="o">=</span><span class="n">state</span><span class="p">,</span> <span class="n">global_step</span><span class="o
         <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">add_checkpoint</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">ckpt_best_name</span><span class="p">,</span> <span class="n">state_dict</span><span class="o">=</span><span class="n">state</span><span class="p">,</span> <span class="n">global_step</span><span class="o
 
 
+    <span class="k">def</span> <span class="nf">_prep_net_for_train</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
+        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">arch_params</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">_init_arch_params</span><span class="p">()</span>
+
+        <span class="c1"># TODO: REMOVE THE BELOW LINE (FOR BACKWARD COMPATIBILITY)</span>
+        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_params</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_params</span> <span class="o">=</span> <span class="n">HpmStruct</span><span class="p">(</span><span class="n">load_checkpoint</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">resume</span><span class="p">)</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">_net_to_device</span><span class="p">()</span>
+
+        <span class="c1"># SET THE FLAG FOR DIFFERENT PARAMETER GROUP OPTIMIZER UPDATE</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">update_param_groups</span> <span class="o">=</span> <span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="o">.</span><span class="n">module</span><span class="p">,</span> <span class="s2">&quot;update_param_groups&quot;</span><span class="p">)</span>
+
+        <span class="bp">self</span><span class="o">.</span><span class="n">checkpoint</span> <span class="o">=</span> <span class="p">{}</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">strict_load</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="p">,</span> <span class="s2">&quot;resume_strict_load&quot;</span><span class="p">,</span> <span class="n">StrictLoad</span><span class="o">.</span><span class="n"
+        <span class="bp">self</span><span class="o">.</span><span class="n">load_ema_as_net</span> <span class="o">=</span> <span class="kc">False</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">load_checkpoint</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="p">,</span> <span class="s2">&quot;resume&quot;</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">external_checkpoint_path</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="p">,</span> <span class="s2">&quot;resume_path&quot;</span><span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">load_checkpoint</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">load_checkpoint</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">external_checkpoint_path</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">ckpt_name</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">get_param</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="p">,</span> <span class="s2">&quot;ckpt_name&quot;</span><span class="p">,</span> <span class="s2">&quot;ckpt_latest.pth&quot;</span><span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_load_checkpoint_to_model</span><span class="p">()</span>
+
+    <span class="k">def</span> <span class="nf">_init_arch_params</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
+        <span class="n">default_arch_params</span> <span class="o">=</span> <span class="n">HpmStruct</span><span class="p">()</span>
+        <span class="n">arch_params</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="p">,</span> <span class="s2">&quot;arch_params&quot;</span><span class="p">,</span> <span class="n">default_arch_params</span><span class="p">)</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">arch_params</span> <span class="o">=</span> <span class="n">default_arch_params</span>
+        <span class="k">if</span> <span class="n">arch_params</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="bp">self</span><span class="o">.</span><span class="n">arch_params</span><span class="o">.</span><span class="n">override</span><span class="p">(</span><span class="o">**</span><span class="n">arch_params</span><span class="o">.</span><span class="n">to_dict</span><span class="p">())</span>
+
     <span class="c1"># FIXME - we need to resolve flake8&#39;s &#39;function is too complex&#39; for this function</span>
     <span class="c1"># FIXME - we need to resolve flake8&#39;s &#39;function is too complex&#39; for this function</span>
-<div class="viewcode-block" id="SgModel.train"><a class="viewcode-back" href="../../../../super_gradients.training.sg_model.html#super_gradients.training.SgModel.train">[docs]</a>    <span class="k">def</span> <span class="nf">train</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">training_params</span><span class="p">:</span> <span class="nb">dict</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()):</span>  <span cla
+<div class="viewcode-block" id="Trainer.train"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.Trainer.train">[docs]</a>    <span class="k">def</span> <span class="nf">train</span><span class="p">(</span>
+        <span class="bp">self</span><span class="p">,</span>
+        <span class="n">model</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span>
+        <span class="n">training_params</span><span class="p">:</span> <span class="nb">dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">train_loader</span><span class="p">:</span> <span class="n">DataLoader</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">valid_loader</span><span class="p">:</span> <span class="n">DataLoader</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+        <span class="n">additional_configs_to_log</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+    <span class="p">):</span>  <span class="c1"># noqa: C901</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
 
 
 <span class="sd">        train - Trains the Model</span>
 <span class="sd">        train - Trains the Model</span>
@@ -596,8 +705,28 @@
 <span class="sd">          the data loaders, as dictionary. The phase context will hold the additional items, under an attribute with</span>
 <span class="sd">          the data loaders, as dictionary. The phase context will hold the additional items, under an attribute with</span>
 <span class="sd">          the same name as the key in this dictionary. Then such items can be accessed through phase callbacks.</span>
 <span class="sd">          the same name as the key in this dictionary. Then such items can be accessed through phase callbacks.</span>
 
 
+<span class="sd">            :param additional_configs_to_log: Dict, dictionary containing configs that will be added to the training&#39;s</span>
+<span class="sd">                sg_logger. Format should be {&quot;Config_title_1&quot;: {...}, &quot;Config_title_2&quot;:{..}}.</span>
+<span class="sd">            :param model: torch.nn.Module, model to train.</span>
 
 
+<span class="sd">            :param train_loader: Dataloader for train set.</span>
+<span class="sd">            :param valid_loader: Dataloader for validation.</span>
 <span class="sd">            :param training_params:</span>
 <span class="sd">            :param training_params:</span>
+
+<span class="sd">                - `resume` : bool (default=False)</span>
+
+<span class="sd">                    Whether to continue training from ckpt with the same experiment name</span>
+<span class="sd">                     (i.e resume from CKPT_ROOT_DIR/EXPERIMENT_NAME/CKPT_NAME)</span>
+
+<span class="sd">                - `ckpt_name` : str (default=ckpt_latest.pth)</span>
+
+<span class="sd">                    The checkpoint (.pth file) filename in CKPT_ROOT_DIR/EXPERIMENT_NAME/ to use when resume=True and</span>
+<span class="sd">                     resume_path=None</span>
+
+<span class="sd">                - `resume_path`: str (default=None)</span>
+
+<span class="sd">                    Explicit checkpoint path (.pth file) to use to resume training.</span>
+
 <span class="sd">                - `max_epochs` : int</span>
 <span class="sd">                - `max_epochs` : int</span>
 
 
 <span class="sd">                    Number of epochs to run training.</span>
 <span class="sd">                    Number of epochs to run training.</span>
@@ -659,10 +788,54 @@
 <span class="sd">                    where the computed loss is the sum of a few components we would like to log- these entries in</span>
 <span class="sd">                    where the computed loss is the sum of a few components we would like to log- these entries in</span>
 <span class="sd">                    loss_items).</span>
 <span class="sd">                    loss_items).</span>
 
 
-<span class="sd">                    When training, set the loss_logging_items_names parameter in train_params to be a list of</span>
-<span class="sd">                    strings, of length n_items who&#39;s ith element is the name of the ith entry in loss_items. Then</span>
-<span class="sd">                    each item will be logged, rendered on tensorboard and &quot;watched&quot; (i.e saving model checkpoints</span>
-<span class="sd">                    according to it).</span>
+<span class="sd">                    IMPORTANT:When dealing with external loss classes, to logg/monitor the loss_items as described</span>
+<span class="sd">                    above by specific string name:</span>
+
+<span class="sd">                    Set a &quot;component_names&quot; property in the loss class, whos instance is passed through train_params,</span>
+<span class="sd">                     to be a list of strings, of length n_items who&#39;s ith element is the name of the ith entry in loss_items.</span>
+<span class="sd">                     Then each item will be logged, rendered on tensorboard and &quot;watched&quot; (i.e saving model checkpoints</span>
+<span class="sd">                     according to it) under &lt;LOSS_CLASS.__name__&gt;&quot;/&quot;&lt;COMPONENT_NAME&gt;. If a single item is returned rather then a</span>
+<span class="sd">                     tuple, it would be logged under &lt;LOSS_CLASS.__name__&gt;. When there is no such attributed, the items</span>
+<span class="sd">                     will be named &lt;LOSS_CLASS.__name__&gt;&quot;/&quot;Loss_&quot;&lt;IDX&gt; according to the length of loss_items</span>
+
+<span class="sd">                    For example:</span>
+<span class="sd">                        class MyLoss(_Loss):</span>
+<span class="sd">                            ...</span>
+<span class="sd">                            def forward(self, inputs, targets):</span>
+<span class="sd">                                ...</span>
+<span class="sd">                                total_loss = comp1 + comp2</span>
+<span class="sd">                                loss_items = torch.cat((total_loss.unsqueeze(0),comp1.unsqueeze(0), comp2.unsqueeze(0)).detach()</span>
+<span class="sd">                                return total_loss, loss_items</span>
+<span class="sd">                            ...</span>
+<span class="sd">                            @property</span>
+<span class="sd">                            def component_names(self):</span>
+<span class="sd">                                return [&quot;total_loss&quot;, &quot;my_1st_component&quot;, &quot;my_2nd_component&quot;]</span>
+
+<span class="sd">                    Trainer.train(...</span>
+<span class="sd">                                    train_params={&quot;loss&quot;:MyLoss(),</span>
+<span class="sd">                                                    ...</span>
+<span class="sd">                                                    &quot;metric_to_watch&quot;: &quot;MyLoss/my_1st_component&quot;}</span>
+
+<span class="sd">                        This will write to log and monitor MyLoss/total_loss, MyLoss/my_1st_component,</span>
+<span class="sd">                         MyLoss/my_2nd_component.</span>
+
+<span class="sd">                   For example:</span>
+<span class="sd">                        class MyLoss2(_Loss):</span>
+<span class="sd">                            ...</span>
+<span class="sd">                            def forward(self, inputs, targets):</span>
+<span class="sd">                                ...</span>
+<span class="sd">                                total_loss = comp1 + comp2</span>
+<span class="sd">                                loss_items = torch.cat((total_loss.unsqueeze(0),comp1.unsqueeze(0), comp2.unsqueeze(0)).detach()</span>
+<span class="sd">                                return total_loss, loss_items</span>
+<span class="sd">                            ...</span>
+
+<span class="sd">                    Trainer.train(...</span>
+<span class="sd">                                    train_params={&quot;loss&quot;:MyLoss(),</span>
+<span class="sd">                                                    ...</span>
+<span class="sd">                                                    &quot;metric_to_watch&quot;: &quot;MyLoss2/loss_0&quot;}</span>
+
+<span class="sd">                        This will write to log and monitor MyLoss2/loss_0, MyLoss2/loss_1, MyLoss2/loss_2</span>
+<span class="sd">                        as they have been named by their positional index in loss_items.</span>
 
 
 <span class="sd">                    Since running logs will save the loss_items in some internal state, it is recommended that</span>
 <span class="sd">                    Since running logs will save the loss_items in some internal state, it is recommended that</span>
 <span class="sd">                    loss_items are detached from their computational graph for memory efficiency.</span>
 <span class="sd">                    loss_items are detached from their computational graph for memory efficiency.</span>
@@ -711,7 +884,7 @@
 <span class="sd">                        is a list referring to the names of each entry in the output metric (torch tensor of size n)</span>
 <span class="sd">                        is a list referring to the names of each entry in the output metric (torch tensor of size n)</span>
 
 
 <span class="sd">                        one of &quot;loss_logging_items_names&quot; i.e which will correspond to an item returned during the</span>
 <span class="sd">                        one of &quot;loss_logging_items_names&quot; i.e which will correspond to an item returned during the</span>
-<span class="sd">                        loss function&#39;s forward pass.</span>
+<span class="sd">                        loss function&#39;s forward pass (see loss docs abov).</span>
 
 
 <span class="sd">                    At the end of each epoch, if a new best metric_to_watch value is achieved, the models checkpoint</span>
 <span class="sd">                    At the end of each epoch, if a new best metric_to_watch value is achieved, the models checkpoint</span>
 <span class="sd">                    is saved in YOUR_PYTHON_PATH/checkpoints/ckpt_best.pth</span>
 <span class="sd">                    is saved in YOUR_PYTHON_PATH/checkpoints/ckpt_best.pth</span>
@@ -799,11 +972,6 @@
 <span class="sd">                    will be added to the tensorboard along with some sample images from the dataset. Currently only</span>
 <span class="sd">                    will be added to the tensorboard along with some sample images from the dataset. Currently only</span>
 <span class="sd">                    detection datasets are supported for analysis.</span>
 <span class="sd">                    detection datasets are supported for analysis.</span>
 
 
-<span class="sd">                -  `save_full_train_log` : bool (default=False)</span>
-
-<span class="sd">                    When set, a full log (of all super_gradients modules, including uncaught exceptions from any other</span>
-<span class="sd">                     module) of the training will be saved in the checkpoint directory under full_train_log.log</span>
-
 <span class="sd">                -  `sg_logger` : Union[AbstractSGLogger, str] (defauls=base_sg_logger)</span>
 <span class="sd">                -  `sg_logger` : Union[AbstractSGLogger, str] (defauls=base_sg_logger)</span>
 
 
 <span class="sd">                    Define the SGLogger object for this training process. The SGLogger handles all disk writes, logs, TensorBoard, remote logging</span>
 <span class="sd">                    Define the SGLogger object for this training process. The SGLogger handles all disk writes, logs, TensorBoard, remote logging</span>
@@ -855,49 +1023,49 @@
 
 
 <span class="sd">                        num_calib_batches: int, number of batches to collect the statistics from.</span>
 <span class="sd">                        num_calib_batches: int, number of batches to collect the statistics from.</span>
 
 
-<span class="sd">                        percentile: float, percentile value to use when SgModel,quant_modules_calib_method=&#39;percentile&#39;.</span>
+<span class="sd">                        percentile: float, percentile value to use when Trainer,quant_modules_calib_method=&#39;percentile&#39;.</span>
 <span class="sd">                         Discarded when other methods are used (Default=99.99).</span>
 <span class="sd">                         Discarded when other methods are used (Default=99.99).</span>
 
 
 
 
 <span class="sd">        :return:</span>
 <span class="sd">        :return:</span>
 <span class="sd">        &quot;&quot;&quot;</span>
 <span class="sd">        &quot;&quot;&quot;</span>
         <span class="k">global</span> <span class="n">logger</span>
         <span class="k">global</span> <span class="n">logger</span>
+        <span class="k">if</span> <span class="n">training_params</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">training_params</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span> <span class="o">=</span> <span class="n">train_loader</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">valid_loader</span> <span class="o">=</span> <span class="n">valid_loader</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">valid_loader</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_set_dataset_params</span><span class="p">()</span>
 
 
-        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">net</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
-            <span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="s1">&#39;Model&#39;</span><span class="p">,</span> <span class="s1">&#39;No model found&#39;</span><span class="p">)</span>
-        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_interface</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
-            <span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="s1">&#39;Data&#39;</span><span class="p">,</span> <span class="s1">&#39;No dataset found&#39;</span><span class="p">)</span>
-
+        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">multi_gpu</span> <span class="o">==</span> <span class="n">MultiGPUMode</span><span class="o">.</span><span class="n">DISTRIBUTED_DATA_PARALLEL</span><span class="p">:</span>
+            <span class="c1"># Note: the dataloader uses sampler of the batch_sampler when it is not None.</span>
+            <span class="n">train_sampler</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span><span class="o">.</span><span class="n">batch_sampler</span><span class="o">.</span><span class="n">sampler</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span><span class="o">.</span><span class="n">batch_sampler</span> <span class="ow">is</span> <span class="ow">not</
+            <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">train_sampler</span><span class="p">,</span> <span class="n">SequentialSampler</span><span class="p">):</span>
+                <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
+                    <span class="s2">&quot;You are using a SequentialSampler on you training dataloader, while working on DDP. &quot;</span>
+                    <span class="s2">&quot;This cancels the DDP benefits since it makes each process iterate through the entire dataset&quot;</span>
+                <span class="p">)</span>
+            <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">train_sampler</span><span class="p">,</span> <span class="p">(</span><span class="n">DistributedSampler</span><span class="p">,</span> <span class="n">InfiniteSampler</span><span class="p">,</span> <span class="n">RepeatAugSampler</span><span class="p">)):</span>
+                <span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
+                    <span class="s2">&quot;The training sampler you are using might not support DDP. &quot;</span>
+                    <span class="s2">&quot;If it doesnt, please use one of the following sampler: DistributedSampler, InfiniteSampler, RepeatAugSampler&quot;</span>
+                <span class="p">)</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span> <span class="o">=</span> <span class="n">TrainingParams</span><span class="p">()</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span> <span class="o">=</span> <span class="n">TrainingParams</span><span class="p">()</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">override</span><span class="p">(</span><span class="o">**</span><span class="n">training_params</span><span class="p">)</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">override</span><span class="p">(</span><span class="o">**</span><span class="n">training_params</span><span class="p">)</span>
 
 
+        <span class="bp">self</span><span class="o">.</span><span class="n">net</span> <span class="o">=</span> <span class="n">model</span>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_prep_net_for_train</span><span class="p">()</span>
+
         <span class="c1"># SET RANDOM SEED</span>
         <span class="c1"># SET RANDOM SEED</span>
-        <span class="n">random_seed</span><span class="p">(</span><span class="n">is_ddp</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">multi_gpu</span> <span class="o">==</span> <span class="n">MultiGPUMode</span><span class="o">.</span><span class="n">DISTRIBUTED_DATA_PARALLEL</span><span class="p">,</span>
-                    <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">seed</span><span class="p">)</span>
+        <span class="n">random_seed</span><span class="p">(</span><span class="n">is_ddp</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">multi_gpu</span> <span class="o">==</span> <span class="n">MultiGPUMode</span><span class="o">.</span><span class="n">DISTRIBUTED_DATA_PARALLEL</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><sp
 
 
         <span class="n">silent_mode</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">silent_mode</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">ddp_silent_mode</span>
         <span class="n">silent_mode</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">silent_mode</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">ddp_silent_mode</span>
         <span class="c1"># METRICS</span>
         <span class="c1"># METRICS</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">_set_train_metrics</span><span class="p">(</span><span class="n">train_metrics_list</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">train_metrics_list</span><span class="p">)</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">_set_train_metrics</span><span class="p">(</span><span class="n">train_metrics_list</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">train_metrics_list</span><span class="p">)</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">_set_valid_metrics</span><span class="p">(</span><span class="n">valid_metrics_list</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">valid_metrics_list</span><span class="p">)</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">_set_valid_metrics</span><span class="p">(</span><span class="n">valid_metrics_list</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">valid_metrics_list</span><span class="p">)</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">loss_logging_items_names</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">loss_logging_items_names</span>
-
-        <span class="bp">self</span><span class="o">.</span><span class="n">results_titles</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;Train_&quot;</span> <span class="o">+</span> <span class="n">t</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span>
-                               <span class="bp">self</span><span class="o">.</span><span class="n">loss_logging_items_names</span> <span class="o">+</span> <span class="n">get_metrics_titles</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">train_metrics</span><span class="p">)]</span> <span class="o">+</span> \
-                              <span class="p">[</span><span class="s2">&quot;Valid_&quot;</span> <span class="o">+</span> <span class="n">t</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span>
-                               <span class="bp">self</span><span class="o">.</span><span class="n">loss_logging_items_names</span> <span class="o">+</span> <span class="n">get_metrics_titles</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">valid_metrics</span><span class="p">)]</span>
 
 
         <span class="c1"># Store the metric to follow (loss\accuracy) and initialize as the worst value</span>
         <span class="c1"># Store the metric to follow (loss\accuracy) and initialize as the worst value</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">metric_to_watch</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">metric_to_watch</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">metric_to_watch</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">metric_to_watch</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">greater_metric_to_watch_is_better</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">greater_metric_to_watch_is_better</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">greater_metric_to_watch_is_better</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">greater_metric_to_watch_is_better</span>
-        <span class="bp">self</span><span class="o">.</span><span class="n">metric_idx_in_results_tuple</span> <span class="o">=</span> <span class="p">(</span>
-            <span class="bp">self</span><span class="o">.</span><span class="n">loss_logging_items_names</span> <span class="o">+</span> <span class="n">get_metrics_titles</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">valid_metrics</span><span class="p">))</span><span class="o">.</span><span class="n">index</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">metric_to_watch</span><span class="p">)<
-
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.training.training_hyperparams.training_hyperparams &mdash; SuperGradients 3.0.3 documentation</title>
  7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
  10. <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
  11. <!--[if lt IE 9]>
  12. <script src="../../../../_static/js/html5shiv.min.js"></script>
  13. <![endif]-->
  14. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
  15. <script src="../../../../_static/jquery.js"></script>
  16. <script src="../../../../_static/underscore.js"></script>
  17. <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
  18. <script src="../../../../_static/doctools.js"></script>
  19. <script src="../../../../_static/sphinx_highlight.js"></script>
  20. <script src="../../../../_static/js/theme.js"></script>
  21. <link rel="index" title="Index" href="../../../../genindex.html" />
  22. <link rel="search" title="Search" href="../../../../search.html" />
  23. </head>
  24. <body class="wy-body-for-nav">
  25. <div class="wy-grid-for-nav">
  26. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  27. <div class="wy-side-scroll">
  28. <div class="wy-side-nav-search" >
  29. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
  30. </a>
  31. <div role="search">
  32. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
  33. <input type="text" name="q" placeholder="Search docs" />
  34. <input type="hidden" name="check_keywords" value="yes" />
  35. <input type="hidden" name="area" value="default" />
  36. </form>
  37. </div>
  38. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  39. <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
  40. <ul>
  41. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
  42. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
  45. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
  46. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
  47. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
  57. </ul>
  58. <p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
  59. <ul>
  60. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
  61. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
  62. </ul>
  63. </div>
  64. </div>
  65. </nav>
  66. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  67. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  68. <a href="../../../../index.html">SuperGradients</a>
  69. </nav>
  70. <div class="wy-nav-content">
  71. <div class="rst-content">
  72. <div role="navigation" aria-label="Page navigation">
  73. <ul class="wy-breadcrumbs">
  74. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  75. <li><a href="../../../index.html">Module code</a> &raquo;</li>
  76. <li>super_gradients.training.training_hyperparams.training_hyperparams</li>
  77. <li class="wy-breadcrumbs-aside">
  78. </li>
  79. </ul>
  80. <hr/>
  81. </div>
  82. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  83. <div itemprop="articleBody">
  84. <h1>Source code for super_gradients.training.training_hyperparams.training_hyperparams</h1><div class="highlight"><pre>
  85. <span></span><span class="kn">import</span> <span class="nn">hydra</span>
  86. <span class="kn">import</span> <span class="nn">pkg_resources</span>
  87. <span class="kn">from</span> <span class="nn">hydra</span> <span class="kn">import</span> <span class="n">compose</span><span class="p">,</span> <span class="n">initialize_config_dir</span>
  88. <span class="kn">from</span> <span class="nn">hydra.core.global_hydra</span> <span class="kn">import</span> <span class="n">GlobalHydra</span>
  89. <span class="kn">from</span> <span class="nn">super_gradients.training.utils.utils</span> <span class="kn">import</span> <span class="n">override_default_params_without_nones</span>
  90. <span class="kn">from</span> <span class="nn">super_gradients.training.utils.hydra_utils</span> <span class="kn">import</span> <span class="n">normalize_path</span>
  91. <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
  92. <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span>
  93. <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
  94. <div class="viewcode-block" id="get"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.get">[docs]</a><span class="k">def</span> <span class="nf">get</span><span class="p">(</span><span class="n">config_name</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Dict</span><span class="p">:</span>
  95. <span class="sd">&quot;&quot;&quot;</span>
  96. <span class="sd"> Class for creating training hyper parameters dictionary, taking defaults from yaml</span>
  97. <span class="sd"> files in src/super_gradients/recipes.</span>
  98. <span class="sd"> :param overriding_params: Dict, dictionary like object containing entries to override in the recipe&#39;s training</span>
  99. <span class="sd"> hyper parameters dictionary.</span>
  100. <span class="sd"> :param config_name: yaml config filename in recipes (for example coco2017_yolox).</span>
  101. <span class="sd"> &quot;&quot;&quot;</span>
  102. <span class="k">if</span> <span class="n">overriding_params</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
  103. <span class="n">overriding_params</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
  104. <span class="n">GlobalHydra</span><span class="o">.</span><span class="n">instance</span><span class="p">()</span><span class="o">.</span><span class="n">clear</span><span class="p">()</span>
  105. <span class="n">sg_recipes_dir</span> <span class="o">=</span> <span class="n">pkg_resources</span><span class="o">.</span><span class="n">resource_filename</span><span class="p">(</span><span class="s2">&quot;super_gradients.recipes&quot;</span><span class="p">,</span> <span class="s2">&quot;&quot;</span><span class="p">)</span>
  106. <span class="k">with</span> <span class="n">initialize_config_dir</span><span class="p">(</span><span class="n">config_dir</span><span class="o">=</span><span class="n">normalize_path</span><span class="p">(</span><span class="n">sg_recipes_dir</span><span class="p">),</span> <span class="n">version_base</span><span class="o">=</span><span class="s2">&quot;1.2&quot;</span><span class="p">):</span>
  107. <span class="n">cfg</span> <span class="o">=</span> <span class="n">compose</span><span class="p">(</span><span class="n">config_name</span><span class="o">=</span><span class="n">normalize_path</span><span class="p">(</span><span class="n">config_name</span><span class="p">))</span>
  108. <span class="n">cfg</span> <span class="o">=</span> <span class="n">hydra</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">instantiate</span><span class="p">(</span><span class="n">cfg</span><span class="p">)</span>
  109. <span class="n">training_params</span> <span class="o">=</span> <span class="n">cfg</span><span class="o">.</span><span class="n">training_hyperparams</span>
  110. <span class="n">training_params</span> <span class="o">=</span> <span class="n">override_default_params_without_nones</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">,</span> <span class="n">training_params</span><span class="p">)</span>
  111. <span class="k">return</span> <span class="n">training_params</span></div>
  112. <div class="viewcode-block" id="cifar10_resnet_train_params"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.cifar10_resnet_train_params">[docs]</a><span class="k">def</span> <span class="nf">cifar10_resnet_train_params</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  113. <span class="k">return</span> <span class="n">get</span><span class="p">(</span><span class="s2">&quot;cifar10_resnet&quot;</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">)</span></div>
  114. <div class="viewcode-block" id="cityscapes_ddrnet_train_params"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.cityscapes_ddrnet_train_params">[docs]</a><span class="k">def</span> <span class="nf">cityscapes_ddrnet_train_params</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  115. <span class="k">return</span> <span class="n">get</span><span class="p">(</span><span class="s2">&quot;cityscapes_ddrnet&quot;</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">)</span></div>
  116. <div class="viewcode-block" id="cityscapes_regseg48_train_params"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.cityscapes_regseg48_train_params">[docs]</a><span class="k">def</span> <span class="nf">cityscapes_regseg48_train_params</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  117. <span class="k">return</span> <span class="n">get</span><span class="p">(</span><span class="s2">&quot;cityscapes_regseg48&quot;</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">)</span></div>
  118. <div class="viewcode-block" id="cityscapes_stdc_base_train_params"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.cityscapes_stdc_base_train_params">[docs]</a><span class="k">def</span> <span class="nf">cityscapes_stdc_base_train_params</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  119. <span class="k">return</span> <span class="n">get</span><span class="p">(</span><span class="s2">&quot;cityscapes_stdc_base&quot;</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">)</span></div>
  120. <div class="viewcode-block" id="cityscapes_stdc_seg50_train_params"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.cityscapes_stdc_seg50_train_params">[docs]</a><span class="k">def</span> <span class="nf">cityscapes_stdc_seg50_train_params</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  121. <span class="k">return</span> <span class="n">get</span><span class="p">(</span><span class="s2">&quot;cityscapes_stdc_seg50&quot;</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">)</span></div>
  122. <div class="viewcode-block" id="cityscapes_stdc_seg75_train_params"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.cityscapes_stdc_seg75_train_params">[docs]</a><span class="k">def</span> <span class="nf">cityscapes_stdc_seg75_train_params</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  123. <span class="k">return</span> <span class="n">get</span><span class="p">(</span><span class="s2">&quot;cityscapes_stdc_seg75&quot;</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">)</span></div>
  124. <div class="viewcode-block" id="coco2017_ssd_lite_mobilenet_v2_train_params"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.coco2017_ssd_lite_mobilenet_v2_train_params">[docs]</a><span class="k">def</span> <span class="nf">coco2017_ssd_lite_mobilenet_v2_train_params</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  125. <span class="k">return</span> <span class="n">get</span><span class="p">(</span><span class="s2">&quot;coco2017_ssd_lite_mobilenet_v2&quot;</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">)</span></div>
  126. <div class="viewcode-block" id="coco2017_yolox_train_params"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.coco2017_yolox_train_params">[docs]</a><span class="k">def</span> <span class="nf">coco2017_yolox_train_params</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  127. <span class="k">return</span> <span class="n">get</span><span class="p">(</span><span class="s2">&quot;coco2017_yolox&quot;</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">)</span></div>
  128. <div class="viewcode-block" id="coco_segmentation_shelfnet_lw_train_params"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.coco_segmentation_shelfnet_lw_train_params">[docs]</a><span class="k">def</span> <span class="nf">coco_segmentation_shelfnet_lw_train_params</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  129. <span class="k">return</span> <span class="n">get</span><span class="p">(</span><span class="s2">&quot;coco_segmentation_shelfnet_lw&quot;</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">)</span></div>
  130. <div class="viewcode-block" id="imagenet_efficientnet_train_params"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.imagenet_efficientnet_train_params">[docs]</a><span class="k">def</span> <span class="nf">imagenet_efficientnet_train_params</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  131. <span class="k">return</span> <span class="n">get</span><span class="p">(</span><span class="s2">&quot;imagenet_efficientnet&quot;</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">)</span></div>
  132. <div class="viewcode-block" id="imagenet_mobilenetv2_train_params"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.imagenet_mobilenetv2_train_params">[docs]</a><span class="k">def</span> <span class="nf">imagenet_mobilenetv2_train_params</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  133. <span class="k">return</span> <span class="n">get</span><span class="p">(</span><span class="s2">&quot;imagenet_mobilenetv2&quot;</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">)</span></div>
  134. <div class="viewcode-block" id="imagenet_mobilenetv3_base_train_params"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.imagenet_mobilenetv3_base_train_params">[docs]</a><span class="k">def</span> <span class="nf">imagenet_mobilenetv3_base_train_params</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  135. <span class="k">return</span> <span class="n">get</span><span class="p">(</span><span class="s2">&quot;imagenet_mobilenetv3_base&quot;</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">)</span></div>
  136. <div class="viewcode-block" id="imagenet_mobilenetv3_large_train_params"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.imagenet_mobilenetv3_large_train_params">[docs]</a><span class="k">def</span> <span class="nf">imagenet_mobilenetv3_large_train_params</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  137. <span class="k">return</span> <span class="n">get</span><span class="p">(</span><span class="s2">&quot;imagenet_mobilenetv3_large&quot;</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">)</span></div>
  138. <div class="viewcode-block" id="imagenet_mobilenetv3_small_train_params"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.imagenet_mobilenetv3_small_train_params">[docs]</a><span class="k">def</span> <span class="nf">imagenet_mobilenetv3_small_train_params</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  139. <span class="k">return</span> <span class="n">get</span><span class="p">(</span><span class="s2">&quot;imagenet_mobilenetv3_small&quot;</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">)</span></div>
  140. <div class="viewcode-block" id="imagenet_regnetY_train_params"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.imagenet_regnetY_train_params">[docs]</a><span class="k">def</span> <span class="nf">imagenet_regnetY_train_params</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  141. <span class="k">return</span> <span class="n">get</span><span class="p">(</span><span class="s2">&quot;imagenet_regnetY&quot;</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">)</span></div>
  142. <div class="viewcode-block" id="imagenet_repvgg_train_params"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.imagenet_repvgg_train_params">[docs]</a><span class="k">def</span> <span class="nf">imagenet_repvgg_train_params</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  143. <span class="k">return</span> <span class="n">get</span><span class="p">(</span><span class="s2">&quot;imagenet_repvgg&quot;</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">)</span></div>
  144. <div class="viewcode-block" id="imagenet_resnet50_train_params"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.imagenet_resnet50_train_params">[docs]</a><span class="k">def</span> <span class="nf">imagenet_resnet50_train_params</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  145. <span class="k">return</span> <span class="n">get</span><span class="p">(</span><span class="s2">&quot;imagenet_resnet50&quot;</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">)</span></div>
  146. <div class="viewcode-block" id="imagenet_resnet50_kd_train_params"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.imagenet_resnet50_kd_train_params">[docs]</a><span class="k">def</span> <span class="nf">imagenet_resnet50_kd_train_params</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  147. <span class="k">return</span> <span class="n">get</span><span class="p">(</span><span class="s2">&quot;imagenet_resnet50_kd&quot;</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">)</span></div>
  148. <div class="viewcode-block" id="imagenet_vit_base_train_params"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.imagenet_vit_base_train_params">[docs]</a><span class="k">def</span> <span class="nf">imagenet_vit_base_train_params</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  149. <span class="k">return</span> <span class="n">get</span><span class="p">(</span><span class="s2">&quot;imagenet_vit_base&quot;</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">)</span></div>
  150. <div class="viewcode-block" id="imagenet_vit_large_train_params"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.training_hyperparams.imagenet_vit_large_train_params">[docs]</a><span class="k">def</span> <span class="nf">imagenet_vit_large_train_params</span><span class="p">(</span><span class="n">overriding_params</span><span class="p">:</span> <span class="n">Dict</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  151. <span class="k">return</span> <span class="n">get</span><span class="p">(</span><span class="s2">&quot;imagenet_vit_large&quot;</span><span class="p">,</span> <span class="n">overriding_params</span><span class="p">)</span></div>
  152. </pre></div>
  153. </div>
  154. </div>
  155. <footer>
  156. <hr/>
  157. <div role="contentinfo">
  158. <p>&#169; Copyright 2021, SuperGradients team.</p>
  159. </div>
  160. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  161. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  162. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  163. </footer>
  164. </div>
  165. </div>
  166. </section>
  167. </div>
  168. <script>
  169. jQuery(function () {
  170. SphinxRtdTheme.Navigation.enable(true);
  171. });
  172. </script>
  173. </body>
  174. </html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.training.transforms.transforms &mdash; SuperGradients 3.0.3 documentation</title>
  7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
  10. <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
  11. <!--[if lt IE 9]>
  12. <script src="../../../../_static/js/html5shiv.min.js"></script>
  13. <![endif]-->
  14. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
  15. <script src="../../../../_static/jquery.js"></script>
  16. <script src="../../../../_static/underscore.js"></script>
  17. <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
  18. <script src="../../../../_static/doctools.js"></script>
  19. <script src="../../../../_static/sphinx_highlight.js"></script>
  20. <script src="../../../../_static/js/theme.js"></script>
  21. <link rel="index" title="Index" href="../../../../genindex.html" />
  22. <link rel="search" title="Search" href="../../../../search.html" />
  23. </head>
  24. <body class="wy-body-for-nav">
  25. <div class="wy-grid-for-nav">
  26. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  27. <div class="wy-side-scroll">
  28. <div class="wy-side-nav-search" >
  29. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
  30. </a>
  31. <div role="search">
  32. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
  33. <input type="text" name="q" placeholder="Search docs" />
  34. <input type="hidden" name="check_keywords" value="yes" />
  35. <input type="hidden" name="area" value="default" />
  36. </form>
  37. </div>
  38. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  39. <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
  40. <ul>
  41. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
  42. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
  45. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
  46. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
  47. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
  57. </ul>
  58. <p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
  59. <ul>
  60. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
  61. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
  62. </ul>
  63. </div>
  64. </div>
  65. </nav>
  66. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  67. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  68. <a href="../../../../index.html">SuperGradients</a>
  69. </nav>
  70. <div class="wy-nav-content">
  71. <div class="rst-content">
  72. <div role="navigation" aria-label="Page navigation">
  73. <ul class="wy-breadcrumbs">
  74. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  75. <li><a href="../../../index.html">Module code</a> &raquo;</li>
  76. <li>super_gradients.training.transforms.transforms</li>
  77. <li class="wy-breadcrumbs-aside">
  78. </li>
  79. </ul>
  80. <hr/>
  81. </div>
  82. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  83. <div itemprop="articleBody">
  84. <h1>Source code for super_gradients.training.transforms.transforms</h1><div class="highlight"><pre>
  85. <span></span><span class="kn">import</span> <span class="nn">collections</span>
  86. <span class="kn">import</span> <span class="nn">math</span>
  87. <span class="kn">import</span> <span class="nn">random</span>
  88. <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">,</span> <span class="n">Dict</span>
  89. <span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span><span class="p">,</span> <span class="n">ImageFilter</span><span class="p">,</span> <span class="n">ImageOps</span>
  90. <span class="kn">from</span> <span class="nn">torchvision</span> <span class="kn">import</span> <span class="n">transforms</span> <span class="k">as</span> <span class="n">transforms</span>
  91. <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
  92. <span class="kn">import</span> <span class="nn">cv2</span>
  93. <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
  94. <span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">get_mosaic_coordinate</span><span class="p">,</span> <span class="n">adjust_box_anns</span><span class="p">,</span> <span class="n">xyxy2cxcywh</span><span class="p">,</span> <span class="n">cxcywh2xyxy</span><span class="p">,</span> <span class="n">DetectionTargetsFormat</span>
  95. <span class="n">image_resample</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">BILINEAR</span>
  96. <span class="n">mask_resample</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">NEAREST</span>
  97. <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
  98. <span class="k">class</span> <span class="nc">SegmentationTransform</span><span class="p">:</span>
  99. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  100. <span class="k">raise</span> <span class="ne">NotImplementedError</span>
  101. <span class="k">def</span> <span class="fm">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  102. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">)</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">&quot;{&quot;</span><span class="p">,</span> <span class="s2">&quot;(&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">&quot;}&quot;</span><span class="p">,</span> <span class="s2">&quot;)&quot;</span><span class="p">)</span>
  103. <span class="k">class</span> <span class="nc">SegResize</span><span class="p">(</span><span class="n">SegmentationTransform</span><span class="p">):</span>
  104. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">w</span><span class="p">):</span>
  105. <span class="bp">self</span><span class="o">.</span><span class="n">h</span> <span class="o">=</span> <span class="n">h</span>
  106. <span class="bp">self</span><span class="o">.</span><span class="n">w</span> <span class="o">=</span> <span class="n">w</span>
  107. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">):</span>
  108. <span class="n">image</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span>
  109. <span class="n">mask</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;mask&quot;</span><span class="p">]</span>
  110. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">resize</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">w</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">h</span><span class="p">),</span> <span class="n">image_resample</span><span class="p">)</span>
  111. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;mask&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">mask</span><span class="o">.</span><span class="n">resize</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">w</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">h</span><span class="p">),</span> <span class="n">mask_resample</span><span class="p">)</span>
  112. <span class="k">return</span> <span class="n">sample</span>
  113. <span class="k">class</span> <span class="nc">SegRandomFlip</span><span class="p">(</span><span class="n">SegmentationTransform</span><span class="p">):</span>
  114. <span class="sd">&quot;&quot;&quot;</span>
  115. <span class="sd"> Randomly flips the image and mask (synchronously) with probability &#39;prob&#39;.</span>
  116. <span class="sd"> &quot;&quot;&quot;</span>
  117. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prob</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">):</span>
  118. <span class="k">assert</span> <span class="mf">0.0</span> <span class="o">&lt;=</span> <span class="n">prob</span> <span class="o">&lt;=</span> <span class="mf">1.0</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;Probability value must be between 0 and 1, found </span><span class="si">{</span><span class="n">prob</span><span class="si">}</span><span class="s2">&quot;</span>
  119. <span class="bp">self</span><span class="o">.</span><span class="n">prob</span> <span class="o">=</span> <span class="n">prob</span>
  120. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
  121. <span class="n">image</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span>
  122. <span class="n">mask</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;mask&quot;</span><span class="p">]</span>
  123. <span class="k">if</span> <span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">()</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">prob</span><span class="p">:</span>
  124. <span class="n">image</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">Image</span><span class="o">.</span><span class="n">FLIP_LEFT_RIGHT</span><span class="p">)</span>
  125. <span class="n">mask</span> <span class="o">=</span> <span class="n">mask</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">Image</span><span class="o">.</span><span class="n">FLIP_LEFT_RIGHT</span><span class="p">)</span>
  126. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">image</span>
  127. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;mask&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">mask</span>
  128. <span class="k">return</span> <span class="n">sample</span>
  129. <span class="k">class</span> <span class="nc">SegRescale</span><span class="p">(</span><span class="n">SegmentationTransform</span><span class="p">):</span>
  130. <span class="sd">&quot;&quot;&quot;</span>
  131. <span class="sd"> Rescales the image and mask (synchronously) while preserving aspect ratio.</span>
  132. <span class="sd"> The rescaling can be done according to scale_factor, short_size or long_size.</span>
  133. <span class="sd"> If more than one argument is given, the rescaling mode is determined by this order: scale_factor, then short_size,</span>
  134. <span class="sd"> then long_size.</span>
  135. <span class="sd"> Args:</span>
  136. <span class="sd"> scale_factor: rescaling is done by multiplying input size by scale_factor:</span>
  137. <span class="sd"> out_size = (scale_factor * w, scale_factor * h)</span>
  138. <span class="sd"> short_size: rescaling is done by determining the scale factor by the ratio short_size / min(h, w).</span>
  139. <span class="sd"> long_size: rescaling is done by determining the scale factor by the ratio long_size / max(h, w).</span>
  140. <span class="sd"> &quot;&quot;&quot;</span>
  141. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">scale_factor</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">short_size</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">long_size</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  142. <span class="bp">self</span><span class="o">.</span><span class="n">scale_factor</span> <span class="o">=</span> <span class="n">scale_factor</span>
  143. <span class="bp">self</span><span class="o">.</span><span class="n">short_size</span> <span class="o">=</span> <span class="n">short_size</span>
  144. <span class="bp">self</span><span class="o">.</span><span class="n">long_size</span> <span class="o">=</span> <span class="n">long_size</span>
  145. <span class="bp">self</span><span class="o">.</span><span class="n">check_valid_arguments</span><span class="p">()</span>
  146. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
  147. <span class="n">image</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span>
  148. <span class="n">mask</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;mask&quot;</span><span class="p">]</span>
  149. <span class="n">w</span><span class="p">,</span> <span class="n">h</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">size</span>
  150. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale_factor</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  151. <span class="n">scale</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale_factor</span>
  152. <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">short_size</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  153. <span class="n">short_size</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">w</span><span class="p">,</span> <span class="n">h</span><span class="p">)</span>
  154. <span class="n">scale</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">short_size</span> <span class="o">/</span> <span class="n">short_size</span>
  155. <span class="k">else</span><span class="p">:</span>
  156. <span class="n">long_size</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">w</span><span class="p">,</span> <span class="n">h</span><span class="p">)</span>
  157. <span class="n">scale</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">long_size</span> <span class="o">/</span> <span class="n">long_size</span>
  158. <span class="n">out_size</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">scale</span> <span class="o">*</span> <span class="n">w</span><span class="p">),</span> <span class="nb">int</span><span class="p">(</span><span class="n">scale</span> <span class="o">*</span> <span class="n">h</span><span class="p">)</span>
  159. <span class="n">image</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span><span class="n">out_size</span><span class="p">,</span> <span class="n">image_resample</span><span class="p">)</span>
  160. <span class="n">mask</span> <span class="o">=</span> <span class="n">mask</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span><span class="n">out_size</span><span class="p">,</span> <span class="n">mask_resample</span><span class="p">)</span>
  161. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">image</span>
  162. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;mask&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">mask</span>
  163. <span class="k">return</span> <span class="n">sample</span>
  164. <span class="k">def</span> <span class="nf">check_valid_arguments</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  165. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale_factor</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">short_size</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">long_size</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
  166. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Must assign one rescale argument: scale_factor, short_size or long_size&quot;</span><span class="p">)</span>
  167. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale_factor</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale_factor</span> <span class="o">&lt;=</span> <span class="mi">0</span><span class="p">:</span>
  168. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Scale factor must be a positive number, found: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">scale_factor</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
  169. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">short_size</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">short_size</span> <span class="o">&lt;=</span> <span class="mi">0</span><span class="p">:</span>
  170. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Short size must be a positive number, found: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">short_size</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
  171. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">long_size</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">long_size</span> <span class="o">&lt;=</span> <span class="mi">0</span><span class="p">:</span>
  172. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Long size must be a positive number, found: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">long_size</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
  173. <span class="k">class</span> <span class="nc">SegRandomRescale</span><span class="p">:</span>
  174. <span class="sd">&quot;&quot;&quot;</span>
  175. <span class="sd"> Random rescale the image and mask (synchronously) while preserving aspect ratio.</span>
  176. <span class="sd"> Scale factor is randomly picked between scales [min, max]</span>
  177. <span class="sd"> Args:</span>
  178. <span class="sd"> scales: scale range tuple (min, max), if scales is a float range will be defined as (1, scales) if scales &gt; 1,</span>
  179. <span class="sd"> otherwise (scales, 1). must be a positive number.</span>
  180. <span class="sd"> &quot;&quot;&quot;</span>
  181. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">scales</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">List</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">2.0</span><span class="p">)):</span>
  182. <span class="bp">self</span><span class="o">.</span><span class="n">scales</span> <span class="o">=</span> <span class="n">scales</span>
  183. <span class="bp">self</span><span class="o">.</span><span class="n">check_valid_arguments</span><span class="p">()</span>
  184. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
  185. <span class="n">image</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span>
  186. <span class="n">mask</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;mask&quot;</span><span class="p">]</span>
  187. <span class="n">w</span><span class="p">,</span> <span class="n">h</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">size</span>
  188. <span class="n">scale</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">scales</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">scales</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
  189. <span class="n">out_size</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">scale</span> <span class="o">*</span> <span class="n">w</span><span class="p">),</span> <span class="nb">int</span><span class="p">(</span><span class="n">scale</span> <span class="o">*</span> <span class="n">h</span><span class="p">)</span>
  190. <span class="n">image</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span><span class="n">out_size</span><span class="p">,</span> <span class="n">image_resample</span><span class="p">)</span>
  191. <span class="n">mask</span> <span class="o">=</span> <span class="n">mask</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span><span class="n">out_size</span><span class="p">,</span> <span class="n">mask_resample</span><span class="p">)</span>
  192. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">image</span>
  193. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;mask&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">mask</span>
  194. <span class="k">return</span> <span class="n">sample</span>
  195. <span class="k">def</span> <span class="nf">check_valid_arguments</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  196. <span class="sd">&quot;&quot;&quot;</span>
  197. <span class="sd"> Check the scale values are valid. if order is wrong, flip the order and return the right scale values.</span>
  198. <span class="sd"> &quot;&quot;&quot;</span>
  199. <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">scales</span><span class="p">,</span> <span class="n">collections</span><span class="o">.</span><span class="n">abc</span><span class="o">.</span><span class="n">Iterable</span><span class="p">):</span>
  200. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">scales</span> <span class="o">&lt;=</span> <span class="mi">1</span><span class="p">:</span>
  201. <span class="bp">self</span><span class="o">.</span><span class="n">scales</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">scales</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
  202. <span class="k">else</span><span class="p">:</span>
  203. <span class="bp">self</span><span class="o">.</span><span class="n">scales</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">scales</span><span class="p">)</span>
  204. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">scales</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">&lt;</span> <span class="mi">0</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">scales</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">&lt;</span> <span class="mi">0</span><span class="p">:</span>
  205. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;SegRandomRescale scale values must be positive numbers, found: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">scales</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
  206. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">scales</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">scales</span><span class="p">[</span><span class="mi">1</span><span class="p">]:</span>
  207. <span class="bp">self</span><span class="o">.</span><span class="n">scales</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">scales</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">scales</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
  208. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">scales</span>
  209. <span class="k">class</span> <span class="nc">SegRandomRotate</span><span class="p">(</span><span class="n">SegmentationTransform</span><span class="p">):</span>
  210. <span class="sd">&quot;&quot;&quot;</span>
  211. <span class="sd"> Randomly rotates image and mask (synchronously) between &#39;min_deg&#39; and &#39;max_deg&#39;.</span>
  212. <span class="sd"> &quot;&quot;&quot;</span>
  213. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">min_deg</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">-</span><span class="mi">10</span><span class="p">,</span> <span class="n">max_deg</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span> <span class="n">fill_mask</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="n">fill_image</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">List</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span><span class="p">):</span>
  214. <span class="bp">self</span><span class="o">.</span><span class="n">min_deg</span> <span class="o">=</span> <span class="n">min_deg</span>
  215. <span class="bp">self</span><span class="o">.</span><span class="n">max_deg</span> <span class="o">=</span> <span class="n">max_deg</span>
  216. <span class="bp">self</span><span class="o">.</span><span class="n">fill_mask</span> <span class="o">=</span> <span class="n">fill_mask</span>
  217. <span class="c1"># grey color in RGB mode</span>
  218. <span class="bp">self</span><span class="o">.</span><span class="n">fill_image</span> <span class="o">=</span> <span class="p">(</span><span class="n">fill_image</span><span class="p">,</span> <span class="n">fill_image</span><span class="p">,</span> <span class="n">fill_image</span><span class="p">)</span>
  219. <span class="bp">self</span><span class="o">.</span><span class="n">check_valid_arguments</span><span class="p">()</span>
  220. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
  221. <span class="n">image</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span>
  222. <span class="n">mask</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;mask&quot;</span><span class="p">]</span>
  223. <span class="n">deg</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">min_deg</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_deg</span><span class="p">)</span>
  224. <span class="n">image</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">rotate</span><span class="p">(</span><span class="n">deg</span><span class="p">,</span> <span class="n">resample</span><span class="o">=</span><span class="n">image_resample</span><span class="p">,</span> <span class="n">fillcolor</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">fill_image</span><span class="p">)</span>
  225. <span class="n">mask</span> <span class="o">=</span> <span class="n">mask</span><span class="o">.</span><span class="n">rotate</span><span class="p">(</span><span class="n">deg</span><span class="p">,</span> <span class="n">resample</span><span class="o">=</span><span class="n">mask_resample</span><span class="p">,</span> <span class="n">fillcolor</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">fill_mask</span><span class="p">)</span>
  226. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">image</span>
  227. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;mask&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">mask</span>
  228. <span class="k">return</span> <span class="n">sample</span>
  229. <span class="k">def</span> <span class="nf">check_valid_arguments</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  230. <span class="bp">self</span><span class="o">.</span><span class="n">fill_mask</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">fill_image</span> <span class="o">=</span> <span class="n">_validate_fill_values_arguments</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fill_mask</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">fill_image</span><span class="p">)</span>
  231. <span class="k">class</span> <span class="nc">SegCropImageAndMask</span><span class="p">(</span><span class="n">SegmentationTransform</span><span class="p">):</span>
  232. <span class="sd">&quot;&quot;&quot;</span>
  233. <span class="sd"> Crops image and mask (synchronously).</span>
  234. <span class="sd"> In &quot;center&quot; mode a center crop is performed while, in &quot;random&quot; mode the drop will be positioned around</span>
  235. <span class="sd"> random coordinates.</span>
  236. <span class="sd"> &quot;&quot;&quot;</span>
  237. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">crop_size</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">List</span><span class="p">],</span> <span class="n">mode</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
  238. <span class="sd">&quot;&quot;&quot;</span>
  239. <span class="sd"> :param crop_size: tuple of (width, height) for the final crop size, if is scalar size is a</span>
  240. <span class="sd"> square (crop_size, crop_size)</span>
  241. <span class="sd"> :param mode: how to choose the center of the crop, &#39;center&#39; for the center of the input image,</span>
  242. <span class="sd"> &#39;random&#39; center the point is chosen randomally</span>
  243. <span class="sd"> &quot;&quot;&quot;</span>
  244. <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span> <span class="o">=</span> <span class="n">crop_size</span>
  245. <span class="bp">self</span><span class="o">.</span><span class="n">mode</span> <span class="o">=</span> <span class="n">mode</span>
  246. <span class="bp">self</span><span class="o">.</span><span class="n">check_valid_arguments</span><span class="p">()</span>
  247. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
  248. <span class="n">image</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span>
  249. <span class="n">mask</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;mask&quot;</span><span class="p">]</span>
  250. <span class="n">w</span><span class="p">,</span> <span class="n">h</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">size</span>
  251. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span> <span class="o">==</span> <span class="s2">&quot;random&quot;</span><span class="p">:</span>
  252. <span class="n">x1</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">w</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
  253. <span class="n">y1</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">h</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
  254. <span class="k">else</span><span class="p">:</span>
  255. <span class="n">x1</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="nb">round</span><span class="p">((</span><span class="n">w</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="o">/</span> <span class="mf">2.0</span><span class="p">))</span>
  256. <span class="n">y1</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="nb">round</span><span class="p">((</span><span class="n">h</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span> <span class="o">/</span> <span class="mf">2.0</span><span class="p">))</span>
  257. <span class="n">image</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">crop</span><span class="p">((</span><span class="n">x1</span><span class="p">,</span> <span class="n">y1</span><span class="p">,</span> <span class="n">x1</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">y1</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]))</span>
  258. <span class="n">mask</span> <span class="o">=</span> <span class="n">mask</span><span class="o">.</span><span class="n">crop</span><span class="p">((</span><span class="n">x1</span><span class="p">,</span> <span class="n">y1</span><span class="p">,</span> <span class="n">x1</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">y1</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]))</span>
  259. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">image</span>
  260. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;mask&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">mask</span>
  261. <span class="k">return</span> <span class="n">sample</span>
  262. <span class="k">def</span> <span class="nf">check_valid_arguments</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  263. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span><span class="s2">&quot;center&quot;</span><span class="p">,</span> <span class="s2">&quot;random&quot;</span><span class="p">]:</span>
  264. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Unsupported mode: found: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="si">}</span><span class="s2">, expected: &#39;center&#39; or &#39;random&#39;&quot;</span><span class="p">)</span>
  265. <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">,</span> <span class="n">collections</span><span class="o">.</span><span class="n">abc</span><span class="o">.</span><span class="n">Iterable</span><span class="p">):</span>
  266. <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">)</span>
  267. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">&lt;=</span> <span class="mi">0</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">&lt;=</span> <span class="mi">0</span><span class="p">:</span>
  268. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Crop size must be positive numbers, found: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
  269. <span class="k">class</span> <span class="nc">SegRandomGaussianBlur</span><span class="p">(</span><span class="n">SegmentationTransform</span><span class="p">):</span>
  270. <span class="sd">&quot;&quot;&quot;</span>
  271. <span class="sd"> Adds random Gaussian Blur to image with probability &#39;prob&#39;.</span>
  272. <span class="sd"> &quot;&quot;&quot;</span>
  273. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prob</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">):</span>
  274. <span class="k">assert</span> <span class="mf">0.0</span> <span class="o">&lt;=</span> <span class="n">prob</span> <span class="o">&lt;=</span> <span class="mf">1.0</span><span class="p">,</span> <span class="s2">&quot;Probability value must be between 0 and 1&quot;</span>
  275. <span class="bp">self</span><span class="o">.</span><span class="n">prob</span> <span class="o">=</span> <span class="n">prob</span>
  276. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
  277. <span class="n">image</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span>
  278. <span class="n">mask</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;mask&quot;</span><span class="p">]</span>
  279. <span class="k">if</span> <span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">()</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">prob</span><span class="p">:</span>
  280. <span class="n">image</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">ImageFilter</span><span class="o">.</span><span class="n">GaussianBlur</span><span class="p">(</span><span class="n">radius</span><span class="o">=</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">()))</span>
  281. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">image</span>
  282. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;mask&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">mask</span>
  283. <span class="k">return</span> <span class="n">sample</span>
  284. <span class="k">class</span> <span class="nc">SegPadShortToCropSize</span><span class="p">(</span><span class="n">SegmentationTransform</span><span class="p">):</span>
  285. <span class="sd">&quot;&quot;&quot;</span>
  286. <span class="sd"> Pads image to &#39;crop_size&#39;.</span>
  287. <span class="sd"> Should be called only after &quot;SegRescale&quot; or &quot;SegRandomRescale&quot; in augmentations pipeline.</span>
  288. <span class="sd"> &quot;&quot;&quot;</span>
  289. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">crop_size</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">List</span><span class="p">],</span> <span class="n">fill_mask</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="n">fill_image</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">List</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span><span class="p">):</span>
  290. <span class="sd">&quot;&quot;&quot;</span>
  291. <span class="sd"> :param crop_size: tuple of (width, height) for the final crop size, if is scalar size is a</span>
  292. <span class="sd"> square (crop_size, crop_size)</span>
  293. <span class="sd"> :param fill_mask: value to fill mask labels background.</span>
  294. <span class="sd"> :param fill_image: grey value to fill image padded background.</span>
  295. <span class="sd"> &quot;&quot;&quot;</span>
  296. <span class="c1"># CHECK IF CROP SIZE IS A ITERABLE OR SCALAR</span>
  297. <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span> <span class="o">=</span> <span class="n">crop_size</span>
  298. <span class="bp">self</span><span class="o">.</span><span class="n">fill_mask</span> <span class="o">=</span> <span class="n">fill_mask</span>
  299. <span class="bp">self</span><span class="o">.</span><span class="n">fill_image</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">fill_image</span><span class="p">)</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">fill_image</span><span class="p">,</span> <span class="n">Sequence</span><span class="p">)</span> <span class="k">else</span> <span class="n">fill_image</span>
  300. <span class="bp">self</span><span class="o">.</span><span class="n">check_valid_arguments</span><span class="p">()</span>
  301. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
  302. <span class="n">image</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span>
  303. <span class="n">mask</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;mask&quot;</span><span class="p">]</span>
  304. <span class="n">w</span><span class="p">,</span> <span class="n">h</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">size</span>
  305. <span class="c1"># pad images from center symmetrically</span>
  306. <span class="k">if</span> <span class="n">w</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="ow">or</span> <span class="n">h</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]:</span>
  307. <span class="n">padh</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="n">h</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span> <span class="k">if</span> <span class="n">h</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="k">else</span> <span class="mi">0</span>
  308. <span class="n">pad_top</span><span class="p">,</span> <span class="n">pad_bottom</span> <span class="o">=</span> <span class="n">math</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">padh</span><span class="p">),</span> <span class="n">math</span><span class="o">.</span><span class="n">floor</span><span class="p">(</span><span class="n">padh</span><span class="p">)</span>
  309. <span class="n">padw</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">-</span> <span class="n">w</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span> <span class="k">if</span> <span class="n">w</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">else</span> <span class="mi">0</span>
  310. <span class="n">pad_left</span><span class="p">,</span> <span class="n">pad_right</span> <span class="o">=</span> <span class="n">math</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">padw</span><span class="p">),</span> <span class="n">math</span><span class="o">.</span><span class="n">floor</span><span class="p">(</span><span class="n">padw</span><span class="p">)</span>
  311. <span class="n">image</span> <span class="o">=</span> <span class="n">ImageOps</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">border</span><span class="o">=</span><span class="p">(</span><span class="n">pad_left</span><span class="p">,</span> <span class="n">pad_top</span><span class="p">,</span> <span class="n">pad_right</span><span class="p">,</span> <span class="n">pad_bottom</span><span class="p">),</span> <span class="n">fill</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">fill_image</span><span class="p">)</span>
  312. <span class="n">mask</span> <span class="o">=</span> <span class="n">ImageOps</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">border</span><span class="o">=</span><span class="p">(</span><span class="n">pad_left</span><span class="p">,</span> <span class="n">pad_top</span><span class="p">,</span> <span class="n">pad_right</span><span class="p">,</span> <span class="n">pad_bottom</span><span class="p">),</span> <span class="n">fill</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">fill_mask</span><span class="p">)</span>
  313. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">image</span>
  314. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;mask&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">mask</span>
  315. <span class="k">return</span> <span class="n">sample</span>
  316. <span class="k">def</span> <span class="nf">check_valid_arguments</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  317. <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">,</span> <span class="n">collections</span><span class="o">.</span><span class="n">abc</span><span class="o">.</span><span class="n">Iterable</span><span class="p">):</span>
  318. <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">)</span>
  319. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">&lt;=</span> <span class="mi">0</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">&lt;=</span> <span class="mi">0</span><span class="p">:</span>
  320. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Crop size must be positive numbers, found: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">crop_size</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
  321. <span class="bp">self</span><span class="o">.</span><span class="n">fill_mask</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">fill_image</span> <span class="o">=</span> <span class="n">_validate_fill_values_arguments</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fill_mask</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">fill_image</span><span class="p">)</span>
  322. <span class="k">class</span> <span class="nc">SegColorJitter</span><span class="p">(</span><span class="n">transforms</span><span class="o">.</span><span class="n">ColorJitter</span><span class="p">):</span>
  323. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">):</span>
  324. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="nb">super</span><span class="p">(</span><span class="n">SegColorJitter</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__call__</span><span class="p">(</span><span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">])</span>
  325. <span class="k">return</span> <span class="n">sample</span>
  326. <span class="k">def</span> <span class="nf">_validate_fill_values_arguments</span><span class="p">(</span><span class="n">fill_mask</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">fill_image</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">List</span><span class="p">]):</span>
  327. <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">fill_image</span><span class="p">,</span> <span class="n">collections</span><span class="o">.</span><span class="n">abc</span><span class="o">.</span><span class="n">Iterable</span><span class="p">):</span>
  328. <span class="c1"># If fill_image is single value, turn to grey color in RGB mode.</span>
  329. <span class="n">fill_image</span> <span class="o">=</span> <span class="p">(</span><span class="n">fill_image</span><span class="p">,</span> <span class="n">fill_image</span><span class="p">,</span> <span class="n">fill_image</span><span class="p">)</span>
  330. <span class="k">elif</span> <span class="nb">len</span><span class="p">(</span><span class="n">fill_image</span><span class="p">)</span> <span class="o">!=</span> <span class="mi">3</span><span class="p">:</span>
  331. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;fill_image must be an RGB tuple of size equal to 3, found: </span><span class="si">{</span><span class="n">fill_image</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
  332. <span class="c1"># assert values are integers</span>
  333. <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">fill_mask</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span> <span class="ow">or</span> <span class="ow">not</span> <span class="nb">all</span><span class="p">(</span><span class="nb">isinstance</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="nb">int</span><span class="p">)</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">fill_image</span><span class="p">):</span>
  334. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Fill value must be integers,&quot;</span> <span class="sa">f</span><span class="s2">&quot; found: fill_image = </span><span class="si">{</span><span class="n">fill_image</span><span class="si">}</span><span class="s2">, fill_mask = </span><span class="si">{</span><span class="n">fill_mask</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
  335. <span class="c1"># assert values in range 0-255</span>
  336. <span class="k">if</span> <span class="nb">min</span><span class="p">(</span><span class="n">fill_image</span><span class="p">)</span> <span class="o">&lt;</span> <span class="mi">0</span> <span class="ow">or</span> <span class="nb">max</span><span class="p">(</span><span class="n">fill_image</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">255</span> <span class="ow">or</span> <span class="n">fill_mask</span> <span class="o">&lt;</span> <span class="mi">0</span> <span class="ow">or</span> <span class="n">fill_mask</span> <span class="o">&gt;</span> <span class="mi">255</span><span class="p">:</span>
  337. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Fill value must be a value from 0 to 255,&quot;</span> <span class="sa">f</span><span class="s2">&quot; found: fill_image = </span><span class="si">{</span><span class="n">fill_image</span><span class="si">}</span><span class="s2">, fill_mask = </span><span class="si">{</span><span class="n">fill_mask</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
  338. <span class="k">return</span> <span class="n">fill_mask</span><span class="p">,</span> <span class="n">fill_image</span>
  339. <span class="k">class</span> <span class="nc">DetectionTransform</span><span class="p">:</span>
  340. <span class="sd">&quot;&quot;&quot;</span>
  341. <span class="sd"> Detection transform base class.</span>
  342. <span class="sd"> Complex transforms that require extra data loading can use the the additional_samples_count attribute in a</span>
  343. <span class="sd"> similar fashion to what&#39;s been done in COCODetectionDataset:</span>
  344. <span class="sd"> self._load_additional_inputs_for_transform(sample, transform)</span>
  345. <span class="sd"> # after the above call, sample[&quot;additional_samples&quot;] holds a list of additional inputs and targets.</span>
  346. <span class="sd"> sample = transform(sample)</span>
  347. <span class="sd"> Attributes:</span>
  348. <span class="sd"> additional_samples_count: (int) additional samples to be loaded.</span>
  349. <span class="sd"> non_empty_targets: (bool) whether the additianl targets can have empty targets or not.</span>
  350. <span class="sd"> &quot;&quot;&quot;</span>
  351. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">additional_samples_count</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="n">non_empty_targets</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
  352. <span class="bp">self</span><span class="o">.</span><span class="n">additional_samples_count</span> <span class="o">=</span> <span class="n">additional_samples_count</span>
  353. <span class="bp">self</span><span class="o">.</span><span class="n">non_empty_targets</span> <span class="o">=</span> <span class="n">non_empty_targets</span>
  354. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">dict</span><span class="p">,</span> <span class="nb">list</span><span class="p">]):</span>
  355. <span class="k">raise</span> <span class="ne">NotImplementedError</span>
  356. <span class="k">def</span> <span class="fm">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  357. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">)</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">&quot;{&quot;</span><span class="p">,</span> <span class="s2">&quot;(&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">&quot;}&quot;</span><span class="p">,</span> <span class="s2">&quot;)&quot;</span><span class="p">)</span>
  358. <div class="viewcode-block" id="DetectionMosaic"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.transforms.DetectionMosaic">[docs]</a><span class="k">class</span> <span class="nc">DetectionMosaic</span><span class="p">(</span><span class="n">DetectionTransform</span><span class="p">):</span>
  359. <span class="sd">&quot;&quot;&quot;</span>
  360. <span class="sd"> DetectionMosaic detection transform</span>
  361. <span class="sd"> Attributes:</span>
  362. <span class="sd"> input_dim: (tuple) input dimension.</span>
  363. <span class="sd"> prob: (float) probability of applying mosaic.</span>
  364. <span class="sd"> enable_mosaic: (bool) whether to apply mosaic at all (regardless of prob) (default=True).</span>
  365. <span class="sd"> &quot;&quot;&quot;</span>
  366. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_dim</span><span class="p">:</span> <span class="nb">tuple</span><span class="p">,</span> <span class="n">prob</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span> <span class="n">enable_mosaic</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span>
  367. <span class="nb">super</span><span class="p">(</span><span class="n">DetectionMosaic</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">additional_samples_count</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span>
  368. <span class="bp">self</span><span class="o">.</span><span class="n">prob</span> <span class="o">=</span> <span class="n">prob</span>
  369. <span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span> <span class="o">=</span> <span class="n">input_dim</span>
  370. <span class="bp">self</span><span class="o">.</span><span class="n">enable_mosaic</span> <span class="o">=</span> <span class="n">enable_mosaic</span>
  371. <div class="viewcode-block" id="DetectionMosaic.close"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.transforms.DetectionMosaic.close">[docs]</a> <span class="k">def</span> <span class="nf">close</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  372. <span class="bp">self</span><span class="o">.</span><span class="n">additional_samples_count</span> <span class="o">=</span> <span class="mi">0</span>
  373. <span class="bp">self</span><span class="o">.</span><span class="n">enable_mosaic</span> <span class="o">=</span> <span class="kc">False</span></div>
  374. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">dict</span><span class="p">,</span> <span class="nb">list</span><span class="p">]):</span>
  375. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">enable_mosaic</span> <span class="ow">and</span> <span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">()</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">prob</span><span class="p">:</span>
  376. <span class="n">mosaic_labels</span> <span class="o">=</span> <span class="p">[]</span>
  377. <span class="n">mosaic_labels_seg</span> <span class="o">=</span> <span class="p">[]</span>
  378. <span class="n">input_h</span><span class="p">,</span> <span class="n">input_w</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
  379. <span class="c1"># yc, xc = s, s # mosaic center x, y</span>
  380. <span class="n">yc</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="mf">0.5</span> <span class="o">*</span> <span class="n">input_h</span><span class="p">,</span> <span class="mf">1.5</span> <span class="o">*</span> <span class="n">input_h</span><span class="p">))</span>
  381. <span class="n">xc</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="mf">0.5</span> <span class="o">*</span> <span class="n">input_w</span><span class="p">,</span> <span class="mf">1.5</span> <span class="o">*</span> <span class="n">input_w</span><span class="p">))</span>
  382. <span class="c1"># 3 additional samples, total of 4</span>
  383. <span class="n">all_samples</span> <span class="o">=</span> <span class="p">[</span><span class="n">sample</span><span class="p">]</span> <span class="o">+</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;additional_samples&quot;</span><span class="p">]</span>
  384. <span class="k">for</span> <span class="n">i_mosaic</span><span class="p">,</span> <span class="n">mosaic_sample</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">all_samples</span><span class="p">):</span>
  385. <span class="n">img</span><span class="p">,</span> <span class="n">_labels</span> <span class="o">=</span> <span class="n">mosaic_sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">],</span> <span class="n">mosaic_sample</span><span class="p">[</span><span class="s2">&quot;target&quot;</span><span class="p">]</span>
  386. <span class="n">_labels_seg</span> <span class="o">=</span> <span class="n">mosaic_sample</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;target_seg&quot;</span><span class="p">)</span>
  387. <span class="n">h0</span><span class="p">,</span> <span class="n">w0</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="mi">2</span><span class="p">]</span> <span class="c1"># orig hw</span>
  388. <span class="n">scale</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="mf">1.0</span> <span class="o">*</span> <span class="n">input_h</span> <span class="o">/</span> <span class="n">h0</span><span class="p">,</span> <span class="mf">1.0</span> <span class="o">*</span> <span class="n">input_w</span> <span class="o">/</span> <span class="n">w0</span><span class="p">)</span>
  389. <span class="n">img</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">w0</span> <span class="o">*</span> <span class="n">scale</span><span class="p">),</span> <span class="nb">int</span><span class="p">(</span><span class="n">h0</span> <span class="o">*</span> <span class="n">scale</span><span class="p">)),</span> <span class="n">interpolation</span><span class="o">=</span><span class="n">cv2</span><span class="o">.</span><span class="n">INTER_LINEAR</span><span class="p">)</span>
  390. <span class="c1"># generate output mosaic image</span>
  391. <span class="p">(</span><span class="n">h</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="n">c</span><span class="p">)</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="mi">3</span><span class="p">]</span>
  392. <span class="k">if</span> <span class="n">i_mosaic</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
  393. <span class="n">mosaic_img</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">input_h</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> <span class="n">input_w</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> <span class="n">c</span><span class="p">),</span> <span class="mi">114</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span>
  394. <span class="c1"># suffix l means large image, while s means small image in mosaic aug.</span>
  395. <span class="p">(</span><span class="n">l_x1</span><span class="p">,</span> <span class="n">l_y1</span><span class="p">,</span> <span class="n">l_x2</span><span class="p">,</span> <span class="n">l_y2</span><span class="p">),</span> <span class="p">(</span><span class="n">s_x1</span><span class="p">,</span> <span class="n">s_y1</span><span class="p">,</span> <span class="n">s_x2</span><span class="p">,</span> <span class="n">s_y2</span><span class="p">)</span> <span class="o">=</span> <span class="n">get_mosaic_coordinate</span><span class="p">(</span><span class="n">i_mosaic</span><span class="p">,</span> <span class="n">xc</span><span class="p">,</span> <span class="n">yc</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">input_h</span><span class="p">,</span> <span class="n">input_w</span><span class="p">)</span>
  396. <span class="n">mosaic_img</span><span class="p">[</span><span class="n">l_y1</span><span class="p">:</span><span class="n">l_y2</span><span class="p">,</span> <span class="n">l_x1</span><span class="p">:</span><span class="n">l_x2</span><span class="p">]</span> <span class="o">=</span> <span class="n">img</span><span class="p">[</span><span class="n">s_y1</span><span class="p">:</span><span class="n">s_y2</span><span class="p">,</span> <span class="n">s_x1</span><span class="p">:</span><span class="n">s_x2</span><span class="p">]</span>
  397. <span class="n">padw</span><span class="p">,</span> <span class="n">padh</span> <span class="o">=</span> <span class="n">l_x1</span> <span class="o">-</span> <span class="n">s_x1</span><span class="p">,</span> <span class="n">l_y1</span> <span class="o">-</span> <span class="n">s_y1</span>
  398. <span class="n">labels</span> <span class="o">=</span> <span class="n">_labels</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
  399. <span class="c1"># Normalized xywh to pixel xyxy format</span>
  400. <span class="k">if</span> <span class="n">_labels</span><span class="o">.</span><span class="n">size</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
  401. <span class="n">labels</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">scale</span> <span class="o">*</span> <span class="n">_labels</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="n">padw</span>
  402. <span class="n">labels</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">scale</span> <span class="o">*</span> <span class="n">_labels</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="n">padh</span>
  403. <span class="n">labels</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">scale</span> <span class="o">*</span> <span class="n">_labels</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">+</span> <span class="n">padw</span>
  404. <span class="n">labels</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">=</span> <span class="n">scale</span> <span class="o">*</span> <span class="n">_labels</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">+</span> <span class="n">padh</span>
  405. <span class="n">mosaic_labels</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">labels</span><span class="p">)</span>
  406. <span class="k">if</span> <span class="n">_labels_seg</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  407. <span class="n">labels_seg</span> <span class="o">=</span> <span class="n">_labels_seg</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
  408. <span class="k">if</span> <span class="n">_labels</span><span class="o">.</span><span class="n">size</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
  409. <span class="n">labels_seg</span><span class="p">[:,</span> <span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">scale</span> <span class="o">*</span> <span class="n">labels_seg</span><span class="p">[:,</span> <span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">+</span> <span class="n">padw</span>
  410. <span class="n">labels_seg</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">scale</span> <span class="o">*</span> <span class="n">labels_seg</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">+</span> <span class="n">padh</span>
  411. <span class="n">mosaic_labels_seg</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">labels_seg</span><span class="p">)</span>
  412. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">mosaic_labels</span><span class="p">):</span>
  413. <span class="n">mosaic_labels</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span><span class="n">mosaic_labels</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
  414. <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">mosaic_labels</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">input_w</span><span class="p">,</span> <span class="n">out</span><span class="o">=</span><span class="n">mosaic_labels</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">])</span>
  415. <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">mosaic_labels</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">],</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">input_h</span><span class="p">,</span> <span class="n">out</span><span class="o">=</span><span class="n">mosaic_labels</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">])</span>
  416. <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">mosaic_labels</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">],</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">input_w</span><span class="p">,</span> <span class="n">out</span><span class="o">=</span><span class="n">mosaic_labels</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">])</span>
  417. <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">mosaic_labels</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">],</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">input_h</span><span class="p">,</span> <span class="n">out</span><span class="o">=</span><span class="n">mosaic_labels</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">])</span>
  418. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">mosaic_labels_seg</span><span class="p">):</span>
  419. <span class="n">mosaic_labels_seg</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span><span class="n">mosaic_labels_seg</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
  420. <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">mosaic_labels_seg</span><span class="p">[:,</span> <span class="p">::</span><span class="mi">2</span><span class="p">],</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">input_w</span><span class="p">,</span> <span class="n">out</span><span class="o">=</span><span class="n">mosaic_labels_seg</span><span class="p">[:,</span> <span class="p">::</span><span class="mi">2</span><span class="p">])</span>
  421. <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">mosaic_labels_seg</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">],</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">input_h</span><span class="p">,</span> <span class="n">out</span><span class="o">=</span><span class="n">mosaic_labels_seg</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">])</span>
  422. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">mosaic_img</span>
  423. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;target&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">mosaic_labels</span>
  424. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;info&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="n">mosaic_img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">mosaic_img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
  425. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">mosaic_labels_seg</span><span class="p">):</span>
  426. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;target_seg&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">mosaic_labels_seg</span>
  427. <span class="k">return</span> <span class="n">sample</span></div>
  428. <div class="viewcode-block" id="DetectionRandomAffine"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.transforms.DetectionRandomAffine">[docs]</a><span class="k">class</span> <span class="nc">DetectionRandomAffine</span><span class="p">(</span><span class="n">DetectionTransform</span><span class="p">):</span>
  429. <span class="sd">&quot;&quot;&quot;</span>
  430. <span class="sd"> DetectionRandomAffine detection transform</span>
  431. <span class="sd"> Attributes:</span>
  432. <span class="sd"> target_size: (tuple) desired output shape.</span>
  433. <span class="sd"> degrees: (Union[tuple, float]) degrees for random rotation, when float the random values are drawn uniformly</span>
  434. <span class="sd"> from (-degrees, degrees)</span>
  435. <span class="sd"> translate: (Union[tuple, float]) translate size (in pixels) for random translation, when float the random values</span>
  436. <span class="sd"> are drawn uniformly from (-translate, translate)</span>
  437. <span class="sd"> scales: (Union[tuple, float]) values for random rescale, when float the random values are drawn uniformly</span>
  438. <span class="sd"> from (0.1-scales, 0.1+scales)</span>
  439. <span class="sd"> shear: (Union[tuple, float]) degrees for random shear, when float the random values are drawn uniformly</span>
  440. <span class="sd"> from (shear, shear)</span>
  441. <span class="sd"> enable: (bool) whether to apply the below transform at all.</span>
  442. <span class="sd"> filter_box_candidates: (bool) whether to filter out transformed bboxes by edge size, area ratio, and aspect ratio (default=False).</span>
  443. <span class="sd"> wh_thr: (float) edge size threshold when filter_box_candidates = True. Bounding oxes with edges smaller</span>
  444. <span class="sd"> then this values will be filtered out. (default=2)</span>
  445. <span class="sd"> ar_thr: (float) aspect ratio threshold filter_box_candidates = True. Bounding boxes with aspect ratio larger</span>
  446. <span class="sd"> then this values will be filtered out. (default=20)</span>
  447. <span class="sd"> area_thr:(float) threshold for area ratio between original image and the transformed one, when when filter_box_candidates = True.</span>
  448. <span class="sd"> Bounding boxes with such ratio smaller then this value will be filtered out. (default=0.1)</span>
  449. <span class="sd"> &quot;&quot;&quot;</span>
  450. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
  451. <span class="bp">self</span><span class="p">,</span> <span class="n">degrees</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">translate</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">scales</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">shear</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">target_size</span><span class="o">=</span><span class="p">(</span><span class="mi">640</span><span class="p">,</span> <span class="mi">640</span><span class="p">),</span> <span class="n">filter_box_candidates</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> <span class="n">wh_thr</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">ar_thr</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">area_thr</span><span class="o">=</span><span class="mf">0.1</span>
  452. <span class="p">):</span>
  453. <span class="nb">super</span><span class="p">(</span><span class="n">DetectionRandomAffine</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
  454. <span class="bp">self</span><span class="o">.</span><span class="n">degrees</span> <span class="o">=</span> <span class="n">degrees</span>
  455. <span class="bp">self</span><span class="o">.</span><span class="n">translate</span> <span class="o">=</span> <span class="n">translate</span>
  456. <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="n">scales</span>
  457. <span class="bp">self</span><span class="o">.</span><span class="n">shear</span> <span class="o">=</span> <span class="n">shear</span>
  458. <span class="bp">self</span><span class="o">.</span><span class="n">target_size</span> <span class="o">=</span> <span class="n">target_size</span>
  459. <span class="bp">self</span><span class="o">.</span><span class="n">enable</span> <span class="o">=</span> <span class="kc">True</span>
  460. <span class="bp">self</span><span class="o">.</span><span class="n">filter_box_candidates</span> <span class="o">=</span> <span class="n">filter_box_candidates</span>
  461. <span class="bp">self</span><span class="o">.</span><span class="n">wh_thr</span> <span class="o">=</span> <span class="n">wh_thr</span>
  462. <span class="bp">self</span><span class="o">.</span><span class="n">ar_thr</span> <span class="o">=</span> <span class="n">ar_thr</span>
  463. <span class="bp">self</span><span class="o">.</span><span class="n">area_thr</span> <span class="o">=</span> <span class="n">area_thr</span>
  464. <div class="viewcode-block" id="DetectionRandomAffine.close"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.transforms.DetectionRandomAffine.close">[docs]</a> <span class="k">def</span> <span class="nf">close</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  465. <span class="bp">self</span><span class="o">.</span><span class="n">enable</span> <span class="o">=</span> <span class="kc">False</span></div>
  466. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
  467. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">enable</span><span class="p">:</span>
  468. <span class="n">img</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="n">random_affine</span><span class="p">(</span>
  469. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">],</span>
  470. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;target&quot;</span><span class="p">],</span>
  471. <span class="n">sample</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;target_seg&quot;</span><span class="p">),</span>
  472. <span class="n">target_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">target_size</span><span class="p">,</span>
  473. <span class="n">degrees</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">degrees</span><span class="p">,</span>
  474. <span class="n">translate</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">translate</span><span class="p">,</span>
  475. <span class="n">scales</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">scale</span><span class="p">,</span>
  476. <span class="n">shear</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">shear</span><span class="p">,</span>
  477. <span class="n">filter_box_candidates</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">filter_box_candidates</span><span class="p">,</span>
  478. <span class="n">wh_thr</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">wh_thr</span><span class="p">,</span>
  479. <span class="n">area_thr</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">area_thr</span><span class="p">,</span>
  480. <span class="n">ar_thr</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">ar_thr</span><span class="p">,</span>
  481. <span class="p">)</span>
  482. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">img</span>
  483. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;target&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">target</span>
  484. <span class="k">return</span> <span class="n">sample</span></div>
  485. <span class="k">class</span> <span class="nc">DetectionMixup</span><span class="p">(</span><span class="n">DetectionTransform</span><span class="p">):</span>
  486. <span class="sd">&quot;&quot;&quot;</span>
  487. <span class="sd"> Mixup detection transform</span>
  488. <span class="sd"> Attributes:</span>
  489. <span class="sd"> input_dim: (tuple) input dimension.</span>
  490. <span class="sd"> mixup_scale: (tuple) scale range for the additional loaded image for mixup.</span>
  491. <span class="sd"> prob: (float) probability of applying mixup.</span>
  492. <span class="sd"> enable_mixup: (bool) whether to apply mixup at all (regardless of prob) (default=True).</span>
  493. <span class="sd"> flip_prob: (float) prbability to apply horizontal flip to the additional sample.</span>
  494. <span class="sd"> &quot;&quot;&quot;</span>
  495. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_dim</span><span class="p">,</span> <span class="n">mixup_scale</span><span class="p">,</span> <span class="n">prob</span><span class="o">=</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">enable_mixup</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">flip_prob</span><span class="o">=</span><span class="mf">0.5</span><span class="p">):</span>
  496. <span class="nb">super</span><span class="p">(</span><span class="n">DetectionMixup</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">additional_samples_count</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">non_empty_targets</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  497. <span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span> <span class="o">=</span> <span class="n">input_dim</span>
  498. <span class="bp">self</span><span class="o">.</span><span class="n">mixup_scale</span> <span class="o">=</span> <span class="n">mixup_scale</span>
  499. <span class="bp">self</span><span class="o">.</span><span class="n">prob</span> <span class="o">=</span> <span class="n">prob</span>
  500. <span class="bp">self</span><span class="o">.</span><span class="n">enable_mixup</span> <span class="o">=</span> <span class="n">enable_mixup</span>
  501. <span class="bp">self</span><span class="o">.</span><span class="n">flip_prob</span> <span class="o">=</span> <span class="n">flip_prob</span>
  502. <span class="k">def</span> <span class="nf">close</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  503. <span class="bp">self</span><span class="o">.</span><span class="n">additional_samples_count</span> <span class="o">=</span> <span class="mi">0</span>
  504. <span class="bp">self</span><span class="o">.</span><span class="n">enable_mixup</span> <span class="o">=</span> <span class="kc">False</span>
  505. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
  506. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">enable_mixup</span> <span class="ow">and</span> <span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">()</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">prob</span><span class="p">:</span>
  507. <span class="n">origin_img</span><span class="p">,</span> <span class="n">origin_labels</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">],</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;target&quot;</span><span class="p">]</span>
  508. <span class="n">cp_sample</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;additional_samples&quot;</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span>
  509. <span class="n">img</span><span class="p">,</span> <span class="n">cp_labels</span> <span class="o">=</span> <span class="n">cp_sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">],</span> <span class="n">cp_sample</span><span class="p">[</span><span class="s2">&quot;target&quot;</span><span class="p">]</span>
  510. <span class="n">cp_boxes</span> <span class="o">=</span> <span class="n">cp_labels</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">]</span>
  511. <span class="n">img</span><span class="p">,</span> <span class="n">cp_boxes</span> <span class="o">=</span> <span class="n">_mirror</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">cp_boxes</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">flip_prob</span><span class="p">)</span>
  512. <span class="c1"># PLUG IN TARGET THE FLIPPED BOXES</span>
  513. <span class="n">cp_labels</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">]</span> <span class="o">=</span> <span class="n">cp_boxes</span>
  514. <span class="n">jit_factor</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">mixup_scale</span><span class="p">)</span>
  515. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
  516. <span class="n">cp_img</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="mi">3</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span> <span class="o">*</span> <span class="mi">114</span>
  517. <span class="k">else</span><span class="p">:</span>
  518. <span class="n">cp_img</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span> <span class="o">*</span> <span class="mi">114</span>
  519. <span class="n">cp_scale_ratio</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">/</span> <span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">/</span> <span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
  520. <span class="n">resized_img</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span>
  521. <span class="n">img</span><span class="p">,</span>
  522. <span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">cp_scale_ratio</span><span class="p">),</span> <span class="nb">int</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">cp_scale_ratio</span><span class="p">)),</span>
  523. <span class="n">interpolation</span><span class="o">=</span><span class="n">cv2</span><span class="o">.</span><span class="n">INTER_LINEAR</span><span class="p">,</span>
  524. <span class="p">)</span>
  525. <span class="n">cp_img</span><span class="p">[:</span> <span class="nb">int</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">cp_scale_ratio</span><span class="p">),</span> <span class="p">:</span> <span class="nb">int</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">cp_scale_ratio</span><span class="p">)]</span> <span class="o">=</span> <span class="n">resized_img</span>
  526. <span class="n">cp_img</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span>
  527. <span class="n">cp_img</span><span class="p">,</span>
  528. <span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">cp_img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">jit_factor</span><span class="p">),</span> <span class="nb">int</span><span class="p">(</span><span class="n">cp_img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">jit_factor</span><span class="p">)),</span>
  529. <span class="p">)</span>
  530. <span class="n">cp_scale_ratio</span> <span class="o">*=</span> <span class="n">jit_factor</span>
  531. <span class="n">origin_h</span><span class="p">,</span> <span class="n">origin_w</span> <span class="o">=</span> <span class="n">cp_img</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="mi">2</span><span class="p">]</span>
  532. <span class="n">target_h</span><span class="p">,</span> <span class="n">target_w</span> <span class="o">=</span> <span class="n">origin_img</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="mi">2</span><span class="p">]</span>
  533. <span class="n">padded_img</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="nb">max</span><span class="p">(</span><span class="n">origin_h</span><span class="p">,</span> <span class="n">target_h</span><span class="p">),</span> <span class="nb">max</span><span class="p">(</span><span class="n">origin_w</span><span class="p">,</span> <span class="n">target_w</span><span class="p">),</span> <span class="mi">3</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span>
  534. <span class="n">padded_img</span><span class="p">[:</span><span class="n">origin_h</span><span class="p">,</span> <span class="p">:</span><span class="n">origin_w</span><span class="p">]</span> <span class="o">=</span> <span class="n">cp_img</span>
  535. <span class="n">x_offset</span><span class="p">,</span> <span class="n">y_offset</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span>
  536. <span class="k">if</span> <span class="n">padded_img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">&gt;</span> <span class="n">target_h</span><span class="p">:</span>
  537. <span class="n">y_offset</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">padded_img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">-</span> <span class="n">target_h</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
  538. <span class="k">if</span> <span class="n">padded_img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">&gt;</span> <span class="n">target_w</span><span class="p">:</span>
  539. <span class="n">x_offset</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">padded_img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="n">target_w</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
  540. <span class="n">padded_cropped_img</span> <span class="o">=</span> <span class="n">padded_img</span><span class="p">[</span><span class="n">y_offset</span> <span class="p">:</span> <span class="n">y_offset</span> <span class="o">+</span> <span class="n">target_h</span><span class="p">,</span> <span class="n">x_offset</span> <span class="p">:</span> <span class="n">x_offset</span> <span class="o">+</span> <span class="n">target_w</span><span class="p">]</span>
  541. <span class="n">cp_bboxes_origin_np</span> <span class="o">=</span> <span class="n">adjust_box_anns</span><span class="p">(</span><span class="n">cp_labels</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">]</span><span class="o">.</span><span class="n">copy</span><span class="p">(),</span> <span class="n">cp_scale_ratio</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">origin_w</span><span class="p">,</span> <span class="n">origin_h</span><span class="p">)</span>
  542. <span class="n">cp_bboxes_transformed_np</span> <span class="o">=</span> <span class="n">cp_bboxes_origin_np</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
  543. <span class="n">cp_bboxes_transformed_np</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">cp_bboxes_transformed_np</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">-</span> <span class="n">x_offset</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">target_w</span><span class="p">)</span>
  544. <span class="n">cp_bboxes_transformed_np</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">cp_bboxes_transformed_np</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">-</span> <span class="n">y_offset</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">target_h</span><span class="p">)</span>
  545. <span class="n">cls_labels</span> <span class="o">=</span> <span class="n">cp_labels</span><span class="p">[:,</span> <span class="mi">4</span><span class="p">:</span><span class="mi">5</span><span class="p">]</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
  546. <span class="n">box_labels</span> <span class="o">=</span> <span class="n">cp_bboxes_transformed_np</span>
  547. <span class="n">labels</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">hstack</span><span class="p">((</span><span class="n">box_labels</span><span class="p">,</span> <span class="n">cls_labels</span><span class="p">))</span>
  548. <span class="n">origin_labels</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">vstack</span><span class="p">((</span><span class="n">origin_labels</span><span class="p">,</span> <span class="n">labels</span><span class="p">))</span>
  549. <span class="n">origin_img</span> <span class="o">=</span> <span class="n">origin_img</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
  550. <span class="n">origin_img</span> <span class="o">=</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">origin_img</span> <span class="o">+</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">padded_cropped_img</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
  551. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">],</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;target&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">origin_img</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">),</span> <span class="n">origin_labels</span>
  552. <span class="k">return</span> <span class="n">sample</span>
  553. <div class="viewcode-block" id="DetectionPaddedRescale"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.transforms.DetectionPaddedRescale">[docs]</a><span class="k">class</span> <span class="nc">DetectionPaddedRescale</span><span class="p">(</span><span class="n">DetectionTransform</span><span class="p">):</span>
  554. <span class="sd">&quot;&quot;&quot;</span>
  555. <span class="sd"> Preprocessing transform to be applied last of all transforms for validation.</span>
  556. <span class="sd"> Image- Rescales and pads to self.input_dim.</span>
  557. <span class="sd"> Targets- pads targets to max_targets, moves the class label to first index, converts boxes format- xyxy -&gt; cxcywh.</span>
  558. <span class="sd"> Attributes:</span>
  559. <span class="sd"> input_dim: (tuple) final input dimension (default=(640,640))</span>
  560. <span class="sd"> swap: image axis&#39;s to be rearranged.</span>
  561. <span class="sd"> &quot;&quot;&quot;</span>
  562. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_dim</span><span class="p">,</span> <span class="n">swap</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">max_targets</span><span class="o">=</span><span class="mi">50</span><span class="p">,</span> <span class="n">pad_value</span><span class="o">=</span><span class="mi">114</span><span class="p">):</span>
  563. <span class="bp">self</span><span class="o">.</span><span class="n">swap</span> <span class="o">=</span> <span class="n">swap</span>
  564. <span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span> <span class="o">=</span> <span class="n">input_dim</span>
  565. <span class="bp">self</span><span class="o">.</span><span class="n">max_targets</span> <span class="o">=</span> <span class="n">max_targets</span>
  566. <span class="bp">self</span><span class="o">.</span><span class="n">pad_value</span> <span class="o">=</span> <span class="n">pad_value</span>
  567. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">]):</span>
  568. <span class="n">img</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">crowd_targets</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">],</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;target&quot;</span><span class="p">],</span> <span class="n">sample</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;crowd_target&quot;</span><span class="p">)</span>
  569. <span class="n">img</span><span class="p">,</span> <span class="n">r</span> <span class="o">=</span> <span class="n">rescale_and_pad_to_size</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">swap</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">pad_value</span><span class="p">)</span>
  570. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">img</span>
  571. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;target&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_rescale_target</span><span class="p">(</span><span class="n">targets</span><span class="p">,</span> <span class="n">r</span><span class="p">)</span>
  572. <span class="k">if</span> <span class="n">crowd_targets</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  573. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;crowd_target&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_rescale_target</span><span class="p">(</span><span class="n">crowd_targets</span><span class="p">,</span> <span class="n">r</span><span class="p">)</span>
  574. <span class="k">return</span> <span class="n">sample</span>
  575. <span class="k">def</span> <span class="nf">_rescale_target</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">targets</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">,</span> <span class="n">r</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">:</span>
  576. <span class="sd">&quot;&quot;&quot;SegRescale the target according to a coefficient used to rescale the image.</span>
  577. <span class="sd"> This is done to have images and targets at the same scale.</span>
  578. <span class="sd"> :param targets: Targets to rescale, shape (batch_size, 6)</span>
  579. <span class="sd"> :param r: SegRescale coefficient that was applied to the image</span>
  580. <span class="sd"> :return: Rescaled targets, shape (batch_size, 6)</span>
  581. <span class="sd"> &quot;&quot;&quot;</span>
  582. <span class="n">targets</span> <span class="o">=</span> <span class="n">targets</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">targets</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="k">else</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">max_targets</span><span class="p">,</span> <span class="mi">5</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
  583. <span class="n">boxes</span><span class="p">,</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">targets</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">],</span> <span class="n">targets</span><span class="p">[:,</span> <span class="mi">4</span><span class="p">]</span>
  584. <span class="n">boxes</span> <span class="o">=</span> <span class="n">xyxy2cxcywh</span><span class="p">(</span><span class="n">boxes</span><span class="p">)</span>
  585. <span class="n">boxes</span> <span class="o">*=</span> <span class="n">r</span>
  586. <span class="n">boxes</span> <span class="o">=</span> <span class="n">cxcywh2xyxy</span><span class="p">(</span><span class="n">boxes</span><span class="p">)</span>
  587. <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">((</span><span class="n">boxes</span><span class="p">,</span> <span class="n">labels</span><span class="p">[:,</span> <span class="n">np</span><span class="o">.</span><span class="n">newaxis</span><span class="p">]),</span> <span class="mi">1</span><span class="p">)</span></div>
  588. <span class="k">class</span> <span class="nc">DetectionHorizontalFlip</span><span class="p">(</span><span class="n">DetectionTransform</span><span class="p">):</span>
  589. <span class="sd">&quot;&quot;&quot;</span>
  590. <span class="sd"> Horizontal Flip for Detection</span>
  591. <span class="sd"> Attributes:</span>
  592. <span class="sd"> prob: float: probability of applying horizontal flip</span>
  593. <span class="sd"> max_targets: int: max objects in single image, padding target to this size in case of empty image.</span>
  594. <span class="sd"> &quot;&quot;&quot;</span>
  595. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prob</span><span class="p">,</span> <span class="n">max_targets</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">120</span><span class="p">):</span>
  596. <span class="nb">super</span><span class="p">(</span><span class="n">DetectionHorizontalFlip</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
  597. <span class="bp">self</span><span class="o">.</span><span class="n">prob</span> <span class="o">=</span> <span class="n">prob</span>
  598. <span class="bp">self</span><span class="o">.</span><span class="n">max_targets</span> <span class="o">=</span> <span class="n">max_targets</span>
  599. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">):</span>
  600. <span class="n">image</span><span class="p">,</span> <span class="n">targets</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">],</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;target&quot;</span><span class="p">]</span>
  601. <span class="n">boxes</span> <span class="o">=</span> <span class="n">targets</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">]</span>
  602. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">boxes</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
  603. <span class="n">targets</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">max_targets</span><span class="p">,</span> <span class="mi">5</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
  604. <span class="n">boxes</span> <span class="o">=</span> <span class="n">targets</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">]</span>
  605. <span class="n">image</span><span class="p">,</span> <span class="n">boxes</span> <span class="o">=</span> <span class="n">_mirror</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">boxes</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">prob</span><span class="p">)</span>
  606. <span class="n">targets</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">]</span> <span class="o">=</span> <span class="n">boxes</span>
  607. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;target&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">targets</span>
  608. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">image</span>
  609. <span class="k">return</span> <span class="n">sample</span>
  610. <div class="viewcode-block" id="DetectionHSV"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.transforms.DetectionHSV">[docs]</a><span class="k">class</span> <span class="nc">DetectionHSV</span><span class="p">(</span><span class="n">DetectionTransform</span><span class="p">):</span>
  611. <span class="sd">&quot;&quot;&quot;</span>
  612. <span class="sd"> Detection HSV transform.</span>
  613. <span class="sd"> &quot;&quot;&quot;</span>
  614. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">prob</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">hgain</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span> <span class="n">sgain</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span> <span class="n">vgain</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span> <span class="n">bgr_channels</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)):</span>
  615. <span class="nb">super</span><span class="p">(</span><span class="n">DetectionHSV</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
  616. <span class="bp">self</span><span class="o">.</span><span class="n">prob</span> <span class="o">=</span> <span class="n">prob</span>
  617. <span class="bp">self</span><span class="o">.</span><span class="n">hgain</span> <span class="o">=</span> <span class="n">hgain</span>
  618. <span class="bp">self</span><span class="o">.</span><span class="n">sgain</span> <span class="o">=</span> <span class="n">sgain</span>
  619. <span class="bp">self</span><span class="o">.</span><span class="n">vgain</span> <span class="o">=</span> <span class="n">vgain</span>
  620. <span class="bp">self</span><span class="o">.</span><span class="n">bgr_channels</span> <span class="o">=</span> <span class="n">bgr_channels</span>
  621. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">:</span> <span class="nb">dict</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">dict</span><span class="p">:</span>
  622. <span class="k">if</span> <span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">()</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">prob</span><span class="p">:</span>
  623. <span class="n">augment_hsv</span><span class="p">(</span><span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">hgain</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sgain</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vgain</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">bgr_channels</span><span class="p">)</span>
  624. <span class="k">return</span> <span class="n">sample</span></div>
  625. <div class="viewcode-block" id="DetectionTargetsFormatTransform"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.transforms.DetectionTargetsFormatTransform">[docs]</a><span class="k">class</span> <span class="nc">DetectionTargetsFormatTransform</span><span class="p">(</span><span class="n">DetectionTransform</span><span class="p">):</span>
  626. <span class="sd">&quot;&quot;&quot;</span>
  627. <span class="sd"> Detection targets format transform</span>
  628. <span class="sd"> Converts targets in input_format to output_format.</span>
  629. <span class="sd"> Attributes:</span>
  630. <span class="sd"> input_format: DetectionTargetsFormat: input target format</span>
  631. <span class="sd"> output_format: DetectionTargetsFormat: output target format</span>
  632. <span class="sd"> min_bbox_edge_size: int: bboxes with edge size lower then this values will be removed.</span>
  633. <span class="sd"> max_targets: int: max objects in single image, padding target to this size.</span>
  634. <span class="sd"> &quot;&quot;&quot;</span>
  635. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
  636. <span class="bp">self</span><span class="p">,</span>
  637. <span class="n">input_format</span><span class="p">:</span> <span class="n">DetectionTargetsFormat</span> <span class="o">=</span> <span class="n">DetectionTargetsFormat</span><span class="o">.</span><span class="n">XYXY_LABEL</span><span class="p">,</span>
  638. <span class="n">output_format</span><span class="p">:</span> <span class="n">DetectionTargetsFormat</span> <span class="o">=</span> <span class="n">DetectionTargetsFormat</span><span class="o">.</span><span class="n">LABEL_CXCYWH</span><span class="p">,</span>
  639. <span class="n">min_bbox_edge_size</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
  640. <span class="n">max_targets</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">120</span><span class="p">,</span>
  641. <span class="p">):</span>
  642. <span class="nb">super</span><span class="p">(</span><span class="n">DetectionTargetsFormatTransform</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
  643. <span class="bp">self</span><span class="o">.</span><span class="n">input_format</span> <span class="o">=</span> <span class="n">input_format</span>
  644. <span class="bp">self</span><span class="o">.</span><span class="n">output_format</span> <span class="o">=</span> <span class="n">output_format</span>
  645. <span class="bp">self</span><span class="o">.</span><span class="n">min_bbox_edge_size</span> <span class="o">=</span> <span class="n">min_bbox_edge_size</span>
  646. <span class="bp">self</span><span class="o">.</span><span class="n">max_targets</span> <span class="o">=</span> <span class="n">max_targets</span>
  647. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sample</span><span class="p">):</span>
  648. <span class="n">normalized_input</span> <span class="o">=</span> <span class="s2">&quot;NORMALIZED&quot;</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_format</span><span class="o">.</span><span class="n">value</span>
  649. <span class="n">normalized_output</span> <span class="o">=</span> <span class="s2">&quot;NORMALIZED&quot;</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_format</span><span class="o">.</span><span class="n">value</span>
  650. <span class="n">normalize</span> <span class="o">=</span> <span class="ow">not</span> <span class="n">normalized_input</span> <span class="ow">and</span> <span class="n">normalized_output</span>
  651. <span class="n">denormalize</span> <span class="o">=</span> <span class="n">normalized_input</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">normalized_output</span>
  652. <span class="n">label_first_in_input</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_format</span><span class="o">.</span><span class="n">value</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">&quot;_&quot;</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="s2">&quot;LABEL&quot;</span>
  653. <span class="n">label_first_in_output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_format</span><span class="o">.</span><span class="n">value</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">&quot;_&quot;</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="s2">&quot;LABEL&quot;</span>
  654. <span class="n">input_xyxy_format</span> <span class="o">=</span> <span class="s2">&quot;XYXY&quot;</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_format</span><span class="o">.</span><span class="n">value</span>
  655. <span class="n">output_xyxy_format</span> <span class="o">=</span> <span class="s2">&quot;XYXY&quot;</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_format</span><span class="o">.</span><span class="n">value</span>
  656. <span class="n">convert2xyxy</span> <span class="o">=</span> <span class="ow">not</span> <span class="n">input_xyxy_format</span> <span class="ow">and</span> <span class="n">output_xyxy_format</span>
  657. <span class="n">convert2cxcy</span> <span class="o">=</span> <span class="n">input_xyxy_format</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">output_xyxy_format</span>
  658. <span class="n">image</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">crowd_targets</span> <span class="o">=</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">],</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;target&quot;</span><span class="p">],</span> <span class="n">sample</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;crowd_target&quot;</span><span class="p">)</span>
  659. <span class="n">_</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">w</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">shape</span>
  660. <span class="k">def</span> <span class="nf">_format_target</span><span class="p">(</span><span class="n">targets_in</span><span class="p">):</span>
  661. <span class="k">if</span> <span class="n">label_first_in_input</span><span class="p">:</span>
  662. <span class="n">labels</span><span class="p">,</span> <span class="n">boxes</span> <span class="o">=</span> <span class="n">targets_in</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">targets_in</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">:]</span>
  663. <span class="k">else</span><span class="p">:</span>
  664. <span class="n">boxes</span><span class="p">,</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">targets_in</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">],</span> <span class="n">targets_in</span><span class="p">[:,</span> <span class="mi">4</span><span class="p">]</span>
  665. <span class="k">if</span> <span class="n">convert2cxcy</span><span class="p">:</span>
  666. <span class="n">boxes</span> <span class="o">=</span> <span class="n">xyxy2cxcywh</span><span class="p">(</span><span class="n">boxes</span><span class="p">)</span>
  667. <span class="k">elif</span> <span class="n">convert2xyxy</span><span class="p">:</span>
  668. <span class="n">boxes</span> <span class="o">=</span> <span class="n">cxcywh2xyxy</span><span class="p">(</span><span class="n">boxes</span><span class="p">)</span>
  669. <span class="k">if</span> <span class="n">normalize</span><span class="p">:</span>
  670. <span class="n">boxes</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">boxes</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">/</span> <span class="n">w</span>
  671. <span class="n">boxes</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">boxes</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">/</span> <span class="n">h</span>
  672. <span class="n">boxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">boxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">/</span> <span class="n">w</span>
  673. <span class="n">boxes</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">=</span> <span class="n">boxes</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">/</span> <span class="n">h</span>
  674. <span class="k">elif</span> <span class="n">denormalize</span><span class="p">:</span>
  675. <span class="n">boxes</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">boxes</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">w</span>
  676. <span class="n">boxes</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">boxes</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">h</span>
  677. <span class="n">boxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">boxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">*</span> <span class="n">w</span>
  678. <span class="n">boxes</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">=</span> <span class="n">boxes</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">*</span> <span class="n">h</span>
  679. <span class="n">min_bbox_edge_size</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">min_bbox_edge_size</span> <span class="o">/</span> <span class="nb">max</span><span class="p">(</span><span class="n">w</span><span class="p">,</span> <span class="n">h</span><span class="p">)</span> <span class="k">if</span> <span class="n">normalized_output</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">min_bbox_edge_size</span>
  680. <span class="n">cxcywh_boxes</span> <span class="o">=</span> <span class="n">boxes</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">output_xyxy_format</span> <span class="k">else</span> <span class="n">xyxy2cxcywh</span><span class="p">(</span><span class="n">boxes</span><span class="o">.</span><span class="n">copy</span><span class="p">())</span>
  681. <span class="n">mask_b</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">minimum</span><span class="p">(</span><span class="n">cxcywh_boxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">],</span> <span class="n">cxcywh_boxes</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">])</span> <span class="o">&gt;</span> <span class="n">min_bbox_edge_size</span>
  682. <span class="n">boxes_t</span> <span class="o">=</span> <span class="n">boxes</span><span class="p">[</span><span class="n">mask_b</span><span class="p">]</span>
  683. <span class="n">labels_t</span> <span class="o">=</span> <span class="n">labels</span><span class="p">[</span><span class="n">mask_b</span><span class="p">]</span>
  684. <span class="n">labels_t</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">labels_t</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
  685. <span class="n">targets_t</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">hstack</span><span class="p">((</span><span class="n">labels_t</span><span class="p">,</span> <span class="n">boxes_t</span><span class="p">))</span> <span class="k">if</span> <span class="n">label_first_in_output</span> <span class="k">else</span> <span class="n">np</span><span class="o">.</span><span class="n">hstack</span><span class="p">((</span><span class="n">boxes_t</span><span class="p">,</span> <span class="n">labels_t</span><span class="p">))</span>
  686. <span class="n">padded_targets</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">max_targets</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span>
  687. <span class="n">padded_targets</span><span class="p">[</span><span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">targets_t</span><span class="p">))[:</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_targets</span><span class="p">]]</span> <span class="o">=</span> <span class="n">targets_t</span><span class="p">[:</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_targets</span><span class="p">]</span>
  688. <span class="n">padded_targets</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ascontiguousarray</span><span class="p">(</span><span class="n">padded_targets</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
  689. <span class="k">return</span> <span class="n">padded_targets</span>
  690. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;target&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">_format_target</span><span class="p">(</span><span class="n">targets</span><span class="p">)</span>
  691. <span class="k">if</span> <span class="n">crowd_targets</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  692. <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;crowd_target&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">_format_target</span><span class="p">(</span><span class="n">crowd_targets</span><span class="p">)</span>
  693. <span class="k">return</span> <span class="n">sample</span></div>
  694. <span class="k">def</span> <span class="nf">get_aug_params</span><span class="p">(</span><span class="n">value</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">tuple</span><span class="p">,</span> <span class="nb">float</span><span class="p">],</span> <span class="n">center</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mi">0</span><span class="p">):</span>
  695. <span class="sd">&quot;&quot;&quot;</span>
  696. <span class="sd"> Generates a random value for augmentations as described below</span>
  697. <span class="sd"> :param value: Union[tuple, float] defines the range of values for generation. Wen tuple-</span>
  698. <span class="sd"> drawn uniformly between (value[0], value[1]), and (center - value, center + value) when float</span>
  699. <span class="sd"> :param center: float, defines center to subtract when value is float.</span>
  700. <span class="sd"> :return: generated value</span>
  701. <span class="sd"> &quot;&quot;&quot;</span>
  702. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">value</span><span class="p">,</span> <span class="nb">float</span><span class="p">):</span>
  703. <span class="k">return</span> <span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">center</span> <span class="o">-</span> <span class="n">value</span><span class="p">,</span> <span class="n">center</span> <span class="o">+</span> <span class="n">value</span><span class="p">)</span>
  704. <span class="k">elif</span> <span class="nb">len</span><span class="p">(</span><span class="n">value</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span>
  705. <span class="k">return</span> <span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">value</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">value</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
  706. <span class="k">else</span><span class="p">:</span>
  707. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
  708. <span class="s2">&quot;Affine params should be either a sequence containing two values</span><span class="se">\</span>
  709. <span class="s2"> or single float values. Got </span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span>
  710. <span class="n">value</span>
  711. <span class="p">)</span>
  712. <span class="p">)</span>
  713. <span class="k">def</span> <span class="nf">get_affine_matrix</span><span class="p">(</span>
  714. <span class="n">target_size</span><span class="p">,</span>
  715. <span class="n">degrees</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span>
  716. <span class="n">translate</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span>
  717. <span class="n">scales</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span>
  718. <span class="n">shear</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span>
  719. <span class="p">):</span>
  720. <span class="sd">&quot;&quot;&quot;</span>
  721. <span class="sd"> Returns a random affine transform matrix.</span>
  722. <span class="sd"> :param target_size: (tuple) desired output shape.</span>
  723. <span class="sd"> :param degrees: (Union[tuple, float]) degrees for random rotation, when float the random values are drawn uniformly</span>
  724. <span class="sd"> from (-degrees, degrees)</span>
  725. <span class="sd"> :param translate: (Union[tuple, float]) translate size (in pixels) for random translation, when float the random values</span>
  726. <span class="sd"> are drawn uniformly from (-translate, translate)</span>
  727. <span class="sd"> :param scales: (Union[tuple, float]) values for random rescale, when float the random values are drawn uniformly</span>
  728. <span class="sd"> from (0.1-scales, 0.1+scales)</span>
  729. <span class="sd"> :param shear: (Union[tuple, float]) degrees for random shear, when float the random values are drawn uniformly</span>
  730. <span class="sd"> from (shear, shear)</span>
  731. <span class="sd"> :return: affine_transform_matrix, drawn_scale</span>
  732. <span class="sd"> &quot;&quot;&quot;</span>
  733. <span class="n">twidth</span><span class="p">,</span> <span class="n">theight</span> <span class="o">=</span> <span class="n">target_size</span>
  734. <span class="c1"># Rotation and Scale</span>
  735. <span class="n">angle</span> <span class="o">=</span> <span class="n">get_aug_params</span><span class="p">(</span><span class="n">degrees</span><span class="p">)</span>
  736. <span class="n">scale</span> <span class="o">=</span> <span class="n">get_aug_params</span><span class="p">(</span><span class="n">scales</span><span class="p">,</span> <span class="n">center</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)</span>
  737. <span class="k">if</span> <span class="n">scale</span> <span class="o">&lt;=</span> <span class="mf">0.0</span><span class="p">:</span>
  738. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Argument scale should be positive&quot;</span><span class="p">)</span>
  739. <span class="n">R</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">getRotationMatrix2D</span><span class="p">(</span><span class="n">angle</span><span class="o">=</span><span class="n">angle</span><span class="p">,</span> <span class="n">center</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="n">scale</span><span class="o">=</span><span class="n">scale</span><span class="p">)</span>
  740. <span class="n">M</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">([</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">])</span>
  741. <span class="c1"># Shear</span>
  742. <span class="n">shear_x</span> <span class="o">=</span> <span class="n">math</span><span class="o">.</span><span class="n">tan</span><span class="p">(</span><span class="n">get_aug_params</span><span class="p">(</span><span class="n">shear</span><span class="p">)</span> <span class="o">*</span> <span class="n">math</span><span class="o">.</span><span class="n">pi</span> <span class="o">/</span> <span class="mi">180</span><span class="p">)</span>
  743. <span class="n">shear_y</span> <span class="o">=</span> <span class="n">math</span><span class="o">.</span><span class="n">tan</span><span class="p">(</span><span class="n">get_aug_params</span><span class="p">(</span><span class="n">shear</span><span class="p">)</span> <span class="o">*</span> <span class="n">math</span><span class="o">.</span><span class="n">pi</span> <span class="o">/</span> <span class="mi">180</span><span class="p">)</span>
  744. <span class="n">M</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">R</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="n">shear_y</span> <span class="o">*</span> <span class="n">R</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
  745. <span class="n">M</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">R</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="n">shear_x</span> <span class="o">*</span> <span class="n">R</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
  746. <span class="c1"># Translation</span>
  747. <span class="n">translation_x</span> <span class="o">=</span> <span class="n">get_aug_params</span><span class="p">(</span><span class="n">translate</span><span class="p">)</span> <span class="o">*</span> <span class="n">twidth</span> <span class="c1"># x translation (pixels)</span>
  748. <span class="n">translation_y</span> <span class="o">=</span> <span class="n">get_aug_params</span><span class="p">(</span><span class="n">translate</span><span class="p">)</span> <span class="o">*</span> <span class="n">theight</span> <span class="c1"># y translation (pixels)</span>
  749. <span class="n">M</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">translation_x</span>
  750. <span class="n">M</span><span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">translation_y</span>
  751. <span class="k">return</span> <span class="n">M</span><span class="p">,</span> <span class="n">scale</span>
  752. <span class="k">def</span> <span class="nf">apply_affine_to_bboxes</span><span class="p">(</span><span class="n">targets</span><span class="p">,</span> <span class="n">targets_seg</span><span class="p">,</span> <span class="n">target_size</span><span class="p">,</span> <span class="n">M</span><span class="p">):</span>
  753. <span class="n">num_gts</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">targets</span><span class="p">)</span>
  754. <span class="n">twidth</span><span class="p">,</span> <span class="n">theight</span> <span class="o">=</span> <span class="n">target_size</span>
  755. <span class="c1"># targets_seg = [B x w x h]</span>
  756. <span class="c1"># if any is_not_nan in axis = 1</span>
  757. <span class="n">seg_is_present_mask</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">logical_or</span><span class="o">.</span><span class="n">reduce</span><span class="p">(</span><span class="o">~</span><span class="n">np</span><span class="o">.</span><span class="n">isnan</span><span class="p">(</span><span class="n">targets_seg</span><span class="p">),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
  758. <span class="n">num_gts_masks</span> <span class="o">=</span> <span class="n">seg_is_present_mask</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span>
  759. <span class="n">num_gts_boxes</span> <span class="o">=</span> <span class="n">num_gts</span> <span class="o">-</span> <span class="n">num_gts_masks</span>
  760. <span class="k">if</span> <span class="n">num_gts_boxes</span><span class="p">:</span>
  761. <span class="c1"># warp corner points</span>
  762. <span class="n">corner_points</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="n">num_gts_boxes</span> <span class="o">*</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
  763. <span class="c1"># x1y1, x2y2, x1y2, x2y1</span>
  764. <span class="n">corner_points</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">targets</span><span class="p">[</span><span class="o">~</span><span class="n">seg_is_present_mask</span><span class="p">][:,</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">]]</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">num_gts_boxes</span> <span class="o">*</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
  765. <span class="n">corner_points</span> <span class="o">=</span> <span class="n">corner_points</span> <span class="o">@</span> <span class="n">M</span><span class="o">.</span><span class="n">T</span> <span class="c1"># apply affine transform</span>
  766. <span class="n">corner_points</span> <span class="o">=</span> <span class="n">corner_points</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">num_gts_boxes</span><span class="p">,</span> <span class="mi">8</span><span class="p">)</span>
  767. <span class="c1"># create new boxes</span>
  768. <span class="n">corner_xs</span> <span class="o">=</span> <span class="n">corner_points</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span>
  769. <span class="n">corner_ys</span> <span class="o">=</span> <span class="n">corner_points</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span>
  770. <span class="n">new_bboxes</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">((</span><span class="n">np</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">corner_xs</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">corner_ys</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">corner_xs</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">corner_ys</span><span class="p">,</span> <span class="mi">1</span><span class="p">)))</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">T</span>
  771. <span class="k">else</span><span class="p">:</span>
  772. <span class="n">new_bboxes</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <span class="mi">4</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float</span><span class="p">)</span>
  773. <span class="k">if</span> <span class="n">num_gts_masks</span><span class="p">:</span>
  774. <span class="c1"># warp segmentation points</span>
  775. <span class="n">num_seg_points</span> <span class="o">=</span> <span class="n">targets_seg</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">//</span> <span class="mi">2</span>
  776. <span class="n">corner_points_seg</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="n">num_gts_masks</span> <span class="o">*</span> <span class="n">num_seg_points</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
  777. <span class="n">corner_points_seg</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">targets_seg</span><span class="p">[</span><span class="n">seg_is_present_mask</span><span class="p">]</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">num_gts_masks</span> <span class="o">*</span> <span class="n">num_seg_points</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
  778. <span class="n">corner_points_seg</span> <span class="o">=</span> <span class="n">corner_points_seg</span> <span class="o">@</span> <span class="n">M</span><span class="o">.</span><span class="n">T</span>
  779. <span class="n">corner_points_seg</span> <span class="o">=</span> <span class="n">corner_points_seg</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">num_gts_masks</span><span class="p">,</span> <span class="n">num_seg_points</span> <span class="o">*</span> <span class="mi">2</span><span class="p">)</span>
  780. <span class="c1"># create new boxes</span>
  781. <span class="n">seg_points_xs</span> <span class="o">=</span> <span class="n">corner_points_seg</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span>
  782. <span class="n">seg_points_ys</span> <span class="o">=</span> <span class="n">corner_points_seg</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span>
  783. <span class="n">new_tight_bboxes</span> <span class="o">=</span> <span class="p">(</span>
  784. <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">((</span><span class="n">np</span><span class="o">.</span><span class="n">nanmin</span><span class="p">(</span><span class="n">seg_points_xs</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">nanmin</span><span class="p">(</span><span class="n">seg_points_ys</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">nanmax</span><span class="p">(</span><span class="n">seg_points_xs</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">nanmax</span><span class="p">(</span><span class="n">seg_points_ys</span><span class="p">,</span> <span class="mi">1</span><span class="p">)))</span>
  785. <span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
  786. <span class="o">.</span><span class="n">T</span>
  787. <span class="p">)</span>
  788. <span class="k">else</span><span class="p">:</span>
  789. <span class="n">new_tight_bboxes</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <span class="mi">4</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float</span><span class="p">)</span>
  790. <span class="n">targets</span><span class="p">[</span><span class="o">~</span><span class="n">seg_is_present_mask</span><span class="p">,</span> <span class="p">:</span><span class="mi">4</span><span class="p">]</span> <span class="o">=</span> <span class="n">new_bboxes</span>
  791. <span class="n">targets</span><span class="p">[</span><span class="n">seg_is_present_mask</span><span class="p">,</span> <span class="p">:</span><span class="mi">4</span><span class="p">]</span> <span class="o">=</span> <span class="n">new_tight_bboxes</span>
  792. <span class="c1"># clip boxes</span>
  793. <span class="n">targets</span><span class="p">[:,</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">]]</span> <span class="o">=</span> <span class="n">targets</span><span class="p">[:,</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">]]</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">twidth</span><span class="p">)</span>
  794. <span class="n">targets</span><span class="p">[:,</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">]]</span> <span class="o">=</span> <span class="n">targets</span><span class="p">[:,</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">]]</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">theight</span><span class="p">)</span>
  795. <span class="k">return</span> <span class="n">targets</span>
  796. <span class="k">def</span> <span class="nf">random_affine</span><span class="p">(</span>
  797. <span class="n">img</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span>
  798. <span class="n">targets</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span> <span class="o">=</span> <span class="p">(),</span>
  799. <span class="n">targets_seg</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
  800. <span class="n">target_size</span><span class="p">:</span> <span class="nb">tuple</span> <span class="o">=</span> <span class="p">(</span><span class="mi">640</span><span class="p">,</span> <span class="mi">640</span><span class="p">),</span>
  801. <span class="n">degrees</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">]</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span>
  802. <span class="n">translate</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">]</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
  803. <span class="n">scales</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">]</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
  804. <span class="n">shear</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">]</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span>
  805. <span class="n">filter_box_candidates</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
  806. <span class="n">wh_thr</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span>
  807. <span class="n">ar_thr</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span>
  808. <span class="n">area_thr</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span>
  809. <span class="p">):</span>
  810. <span class="sd">&quot;&quot;&quot;</span>
  811. <span class="sd"> Performs random affine transform to img, targets</span>
  812. <span class="sd"> :param img: Input image</span>
  813. <span class="sd"> :param targets: Input target</span>
  814. <span class="sd"> :param targets_seg: Targets derived from segmentation masks</span>
  815. <span class="sd"> :param target_size: Desired output shape</span>
  816. <span class="sd"> :param degrees: Degrees for random rotation, when float the random values are drawn uniformly</span>
  817. <span class="sd"> from (-degrees, degrees).</span>
  818. <span class="sd"> :param translate: Translate size (in pixels) for random translation, when float the random values</span>
  819. <span class="sd"> are drawn uniformly from (-translate, translate)</span>
  820. <span class="sd"> :param scales: Values for random rescale, when float the random values are drawn uniformly</span>
  821. <span class="sd"> from (0.1-scales, 0.1+scales)</span>
  822. <span class="sd"> :param shear: Degrees for random shear, when float the random values are drawn uniformly</span>
  823. <span class="sd"> from (shear, shear)</span>
  824. <span class="sd"> :param filter_box_candidates: whether to filter out transformed bboxes by edge size, area ratio, and aspect ratio.</span>
  825. <span class="sd"> :param wh_thr: (float) edge size threshold when filter_box_candidates = True. Bounding oxes with edges smaller</span>
  826. <span class="sd"> then this values will be filtered out. (default=2)</span>
  827. <span class="sd"> :param ar_thr: (float) aspect ratio threshold filter_box_candidates = True. Bounding boxes with aspect ratio larger</span>
  828. <span class="sd"> then this values will be filtered out. (default=20)</span>
  829. <span class="sd"> :param area_thr:(float) threshold for area ratio between original image and the transformed one, when when filter_box_candidates = True.</span>
  830. <span class="sd"> Bounding boxes with such ratio smaller then this value will be filtered out. (default=0.1)</span>
  831. <span class="sd"> :return: Image and Target with applied random affine</span>
  832. <span class="sd"> &quot;&quot;&quot;</span>
  833. <span class="n">targets_seg</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">targets</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="mi">0</span><span class="p">))</span> <span class="k">if</span> <span class="n">targets_seg</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">targets_seg</span>
  834. <span class="n">M</span><span class="p">,</span> <span class="n">scale</span> <span class="o">=</span> <span class="n">get_affine_matrix</span><span class="p">(</span><span class="n">target_size</span><span class="p">,</span> <span class="n">degrees</span><span class="p">,</span> <span class="n">translate</span><span class="p">,</span> <span class="n">scales</span><span class="p">,</span> <span class="n">shear</span><span class="p">)</span>
  835. <span class="n">img</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">warpAffine</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">M</span><span class="p">,</span> <span class="n">dsize</span><span class="o">=</span><span class="n">target_size</span><span class="p">,</span> <span class="n">borderValue</span><span class="o">=</span><span class="p">(</span><span class="mi">114</span><span class="p">,</span> <span class="mi">114</span><span class="p">,</span> <span class="mi">114</span><span class="p">))</span>
  836. <span class="c1"># Transform label coordinates</span>
  837. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">targets</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
  838. <span class="n">targets_orig</span> <span class="o">=</span> <span class="n">targets</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
  839. <span class="n">targets</span> <span class="o">=</span> <span class="n">apply_affine_to_bboxes</span><span class="p">(</span><span class="n">targets</span><span class="p">,</span> <span class="n">targets_seg</span><span class="p">,</span> <span class="n">target_size</span><span class="p">,</span> <span class="n">M</span><span class="p">)</span>
  840. <span class="k">if</span> <span class="n">filter_box_candidates</span><span class="p">:</span>
  841. <span class="n">box_candidates_ids</span> <span class="o">=</span> <span class="n">_filter_box_candidates</span><span class="p">(</span><span class="n">targets_orig</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">],</span> <span class="n">targets</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">],</span> <span class="n">wh_thr</span><span class="o">=</span><span class="n">wh_thr</span><span class="p">,</span> <span class="n">ar_thr</span><span class="o">=</span><span class="n">ar_thr</span><span class="p">,</span> <span class="n">area_thr</span><span class="o">=</span><span class="n">area_thr</span><span class="p">)</span>
  842. <span class="n">targets</span> <span class="o">=</span> <span class="n">targets</span><span class="p">[</span><span class="n">box_candidates_ids</span><span class="p">]</span>
  843. <span class="k">return</span> <span class="n">img</span><span class="p">,</span> <span class="n">targets</span>
  844. <span class="k">def</span> <span class="nf">_filter_box_candidates</span><span class="p">(</span><span class="n">box1</span><span class="p">,</span> <span class="n">box2</span><span class="p">,</span> <span class="n">wh_thr</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">ar_thr</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">area_thr</span><span class="o">=</span><span class="mf">0.1</span><span class="p">):</span>
  845. <span class="sd">&quot;&quot;&quot;</span>
  846. <span class="sd"> compute candidate boxes</span>
  847. <span class="sd"> :param box1: before augment</span>
  848. <span class="sd"> :param box2: after augment</span>
  849. <span class="sd"> :param wh_thr: wh_thr (pixels)</span>
  850. <span class="sd"> :param ar_thr: aspect_ratio_thr</span>
  851. <span class="sd"> :param area_thr: area_ratio</span>
  852. <span class="sd"> :return:</span>
  853. <span class="sd"> &quot;&quot;&quot;</span>
  854. <span class="n">box1</span> <span class="o">=</span> <span class="n">box1</span><span class="o">.</span><span class="n">T</span>
  855. <span class="n">box2</span> <span class="o">=</span> <span class="n">box2</span><span class="o">.</span><span class="n">T</span>
  856. <span class="n">w1</span><span class="p">,</span> <span class="n">h1</span> <span class="o">=</span> <span class="n">box1</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">-</span> <span class="n">box1</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">box1</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span> <span class="o">-</span> <span class="n">box1</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
  857. <span class="n">w2</span><span class="p">,</span> <span class="n">h2</span> <span class="o">=</span> <span class="n">box2</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">-</span> <span class="n">box2</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">box2</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span> <span class="o">-</span> <span class="n">box2</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
  858. <span class="n">ar</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">maximum</span><span class="p">(</span><span class="n">w2</span> <span class="o">/</span> <span class="p">(</span><span class="n">h2</span> <span class="o">+</span> <span class="mf">1e-16</span><span class="p">),</span> <span class="n">h2</span> <span class="o">/</span> <span class="p">(</span><span class="n">w2</span> <span class="o">+</span> <span class="mf">1e-16</span><span class="p">))</span> <span class="c1"># aspect ratio</span>
  859. <span class="k">return</span> <span class="p">(</span><span class="n">w2</span> <span class="o">&gt;</span> <span class="n">wh_thr</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">h2</span> <span class="o">&gt;</span> <span class="n">wh_thr</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">w2</span> <span class="o">*</span> <span class="n">h2</span> <span class="o">/</span> <span class="p">(</span><span class="n">w1</span> <span class="o">*</span> <span class="n">h1</span> <span class="o">+</span> <span class="mf">1e-16</span><span class="p">)</span> <span class="o">&gt;</span> <span class="n">area_thr</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">ar</span> <span class="o">&lt;</span> <span class="n">ar_thr</span><span class="p">)</span> <span class="c1"># candidates</span>
  860. <span class="k">def</span> <span class="nf">_mirror</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">boxes</span><span class="p">,</span> <span class="n">prob</span><span class="o">=</span><span class="mf">0.5</span><span class="p">):</span>
  861. <span class="sd">&quot;&quot;&quot;</span>
  862. <span class="sd"> Horizontal flips image and bboxes with probability prob.</span>
  863. <span class="sd"> :param image: (np.array) image to be flipped.</span>
  864. <span class="sd"> :param boxes: (np.array) bboxes to be modified.</span>
  865. <span class="sd"> :param prob: probability to perform flipping.</span>
  866. <span class="sd"> :return: flipped_image, flipped_bboxes</span>
  867. <span class="sd"> &quot;&quot;&quot;</span>
  868. <span class="n">flipped_boxes</span> <span class="o">=</span> <span class="n">boxes</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
  869. <span class="n">_</span><span class="p">,</span> <span class="n">width</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">shape</span>
  870. <span class="k">if</span> <span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">()</span> <span class="o">&lt;</span> <span class="n">prob</span><span class="p">:</span>
  871. <span class="n">image</span> <span class="o">=</span> <span class="n">image</span><span class="p">[:,</span> <span class="p">::</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
  872. <span class="n">flipped_boxes</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">width</span> <span class="o">-</span> <span class="n">boxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">::</span><span class="o">-</span><span class="mi">2</span><span class="p">]</span>
  873. <span class="k">return</span> <span class="n">image</span><span class="p">,</span> <span class="n">flipped_boxes</span>
  874. <span class="k">def</span> <span class="nf">augment_hsv</span><span class="p">(</span><span class="n">img</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">,</span> <span class="n">hgain</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">sgain</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">vgain</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">bgr_channels</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)):</span>
  875. <span class="n">hsv_augs</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span> <span class="o">*</span> <span class="p">[</span><span class="n">hgain</span><span class="p">,</span> <span class="n">sgain</span><span class="p">,</span> <span class="n">vgain</span><span class="p">]</span> <span class="c1"># random gains</span>
  876. <span class="n">hsv_augs</span> <span class="o">*=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span> <span class="c1"># random selection of h, s, v</span>
  877. <span class="n">hsv_augs</span> <span class="o">=</span> <span class="n">hsv_augs</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int16</span><span class="p">)</span>
  878. <span class="n">img_hsv</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">cvtColor</span><span class="p">(</span><span class="n">img</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="n">bgr_channels</span><span class="p">],</span> <span class="n">cv2</span><span class="o">.</span><span class="n">COLOR_BGR2HSV</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int16</span><span class="p">)</span>
  879. <span class="n">img_hsv</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="n">img_hsv</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="n">hsv_augs</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="o">%</span> <span class="mi">180</span>
  880. <span class="n">img_hsv</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">img_hsv</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="n">hsv_augs</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">255</span><span class="p">)</span>
  881. <span class="n">img_hsv</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">img_hsv</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">+</span> <span class="n">hsv_augs</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">255</span><span class="p">)</span>
  882. <span class="n">img</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="n">bgr_channels</span><span class="p">]</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">cvtColor</span><span class="p">(</span><span class="n">img_hsv</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">dtype</span><span class="p">),</span> <span class="n">cv2</span><span class="o">.</span><span class="n">COLOR_HSV2BGR</span><span class="p">)</span> <span class="c1"># no return needed</span>
  883. <span class="k">def</span> <span class="nf">rescale_and_pad_to_size</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">input_size</span><span class="p">,</span> <span class="n">swap</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">pad_val</span><span class="o">=</span><span class="mi">114</span><span class="p">):</span>
  884. <span class="sd">&quot;&quot;&quot;</span>
  885. <span class="sd"> Rescales image according to minimum ratio between the target height /image height, target width / image width,</span>
  886. <span class="sd"> and pads the image to the target size.</span>
  887. <span class="sd"> :param img: Image to be rescaled</span>
  888. <span class="sd"> :param input_size: Target size</span>
  889. <span class="sd"> :param swap: Axis&#39;s to be rearranged.</span>
  890. <span class="sd"> :return: rescaled image, ratio</span>
  891. <span class="sd"> &quot;&quot;&quot;</span>
  892. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
  893. <span class="n">padded_img</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="n">input_size</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">input_size</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span> <span class="o">*</span> <span class="n">pad_val</span>
  894. <span class="k">else</span><span class="p">:</span>
  895. <span class="n">padded_img</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">input_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span> <span class="o">*</span> <span class="n">pad_val</span>
  896. <span class="n">r</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">input_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">/</span> <span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">input_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">/</span> <span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
  897. <span class="n">resized_img</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span>
  898. <span class="n">img</span><span class="p">,</span>
  899. <span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">r</span><span class="p">),</span> <span class="nb">int</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">r</span><span class="p">)),</span>
  900. <span class="n">interpolation</span><span class="o">=</span><span class="n">cv2</span><span class="o">.</span><span class="n">INTER_LINEAR</span><span class="p">,</span>
  901. <span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span>
  902. <span class="n">padded_img</span><span class="p">[:</span> <span class="nb">int</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">r</span><span class="p">),</span> <span class="p">:</span> <span class="nb">int</span><span class="p">(</span><span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">r</span><span class="p">)]</span> <span class="o">=</span> <span class="n">resized_img</span>
  903. <span class="n">padded_img</span> <span class="o">=</span> <span class="n">padded_img</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">swap</span><span class="p">)</span>
  904. <span class="n">padded_img</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ascontiguousarray</span><span class="p">(</span><span class="n">padded_img</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
  905. <span class="k">return</span> <span class="n">padded_img</span><span class="p">,</span> <span class="n">r</span>
  906. </pre></div>
  907. </div>
  908. </div>
  909. <footer>
  910. <hr/>
  911. <div role="contentinfo">
  912. <p>&#169; Copyright 2021, SuperGradients team.</p>
  913. </div>
  914. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  915. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  916. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  917. </footer>
  918. </div>
  919. </div>
  920. </section>
  921. </div>
  922. <script>
  923. jQuery(function () {
  924. SphinxRtdTheme.Navigation.enable(true);
  925. });
  926. </script>
  927. </body>
  928. </html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.training.utils.callbacks &mdash; SuperGradients 1.0 documentation</title>
  7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
  10. <!--[if lt IE 9]>
  11. <script src="../../../../_static/js/html5shiv.min.js"></script>
  12. <![endif]-->
  13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
  14. <script src="../../../../_static/jquery.js"></script>
  15. <script src="../../../../_static/underscore.js"></script>
  16. <script src="../../../../_static/doctools.js"></script>
  17. <script src="../../../../_static/js/theme.js"></script>
  18. <link rel="index" title="Index" href="../../../../genindex.html" />
  19. <link rel="search" title="Search" href="../../../../search.html" />
  20. </head>
  21. <body class="wy-body-for-nav">
  22. <div class="wy-grid-for-nav">
  23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  24. <div class="wy-side-scroll">
  25. <div class="wy-side-nav-search" >
  26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
  27. </a>
  28. <div role="search">
  29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
  30. <input type="text" name="q" placeholder="Search docs" />
  31. <input type="hidden" name="check_keywords" value="yes" />
  32. <input type="hidden" name="area" value="default" />
  33. </form>
  34. </div>
  35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
  37. <ul>
  38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
  39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
  40. </ul>
  41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
  42. <ul>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
  45. </ul>
  46. <p class="caption"><span class="caption-text">User Guide</span></p>
  47. <ul>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
  57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
  58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
  59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
  60. </ul>
  61. </div>
  62. </div>
  63. </nav>
  64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  66. <a href="../../../../index.html">SuperGradients</a>
  67. </nav>
  68. <div class="wy-nav-content">
  69. <div class="rst-content">
  70. <div role="navigation" aria-label="Page navigation">
  71. <ul class="wy-breadcrumbs">
  72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
  74. <li>super_gradients.training.utils.callbacks</li>
  75. <li class="wy-breadcrumbs-aside">
  76. </li>
  77. </ul>
  78. <hr/>
  79. </div>
  80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  81. <div itemprop="articleBody">
  82. <h1>Source code for super_gradients.training.utils.callbacks</h1><div class="highlight"><pre>
  83. <span></span><span class="kn">import</span> <span class="nn">copy</span>
  84. <span class="kn">import</span> <span class="nn">os</span>
  85. <span class="kn">import</span> <span class="nn">time</span>
  86. <span class="kn">from</span> <span class="nn">enum</span> <span class="kn">import</span> <span class="n">Enum</span>
  87. <span class="kn">import</span> <span class="nn">math</span>
  88. <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
  89. <span class="kn">import</span> <span class="nn">onnx</span>
  90. <span class="kn">import</span> <span class="nn">onnxruntime</span>
  91. <span class="kn">import</span> <span class="nn">torch</span>
  92. <span class="kn">import</span> <span class="nn">signal</span>
  93. <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span>
  94. <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
  95. <span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">DetectionVisualization</span><span class="p">,</span> <span class="n">DetectionPostPredictionCallback</span>
  96. <span class="kn">from</span> <span class="nn">super_gradients.training.utils.segmentation_utils</span> <span class="kn">import</span> <span class="n">BinarySegmentationVisualization</span>
  97. <span class="kn">import</span> <span class="nn">cv2</span>
  98. <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
  99. <span class="k">try</span><span class="p">:</span>
  100. <span class="kn">from</span> <span class="nn">deci_lab_client.client</span> <span class="kn">import</span> <span class="n">DeciPlatformClient</span>
  101. <span class="kn">from</span> <span class="nn">deci_lab_client.models</span> <span class="kn">import</span> <span class="n">ModelBenchmarkState</span>
  102. <span class="kn">from</span> <span class="nn">deci_lab_client.models.model_metadata</span> <span class="kn">import</span> <span class="n">ModelMetadata</span>
  103. <span class="n">_imported_deci_lab_failure</span> <span class="o">=</span> <span class="kc">None</span>
  104. <span class="k">except</span> <span class="p">(</span><span class="ne">ImportError</span><span class="p">,</span> <span class="ne">NameError</span><span class="p">,</span> <span class="ne">ModuleNotFoundError</span><span class="p">)</span> <span class="k">as</span> <span class="n">import_err</span><span class="p">:</span>
  105. <span class="n">logger</span><span class="o">.</span><span class="n">warn</span><span class="p">(</span><span class="s2">&quot;Failed to import deci_lab_client&quot;</span><span class="p">)</span>
  106. <span class="n">_imported_deci_lab_failure</span> <span class="o">=</span> <span class="n">import_err</span>
  107. <div class="viewcode-block" id="Phase"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.Phase">[docs]</a><span class="k">class</span> <span class="nc">Phase</span><span class="p">(</span><span class="n">Enum</span><span class="p">):</span>
  108. <span class="n">PRE_TRAINING</span> <span class="o">=</span> <span class="s2">&quot;PRE_TRAINING&quot;</span>
  109. <span class="n">TRAIN_BATCH_END</span> <span class="o">=</span> <span class="s2">&quot;TRAIN_BATCH_END&quot;</span>
  110. <span class="n">TRAIN_BATCH_STEP</span> <span class="o">=</span> <span class="s2">&quot;TRAIN_BATCH_STEP&quot;</span>
  111. <span class="n">TRAIN_EPOCH_START</span> <span class="o">=</span> <span class="s2">&quot;TRAIN_EPOCH_START&quot;</span>
  112. <span class="n">TRAIN_EPOCH_END</span> <span class="o">=</span> <span class="s2">&quot;TRAIN_EPOCH_END&quot;</span>
  113. <span class="n">VALIDATION_BATCH_END</span> <span class="o">=</span> <span class="s2">&quot;VALIDATION_BATCH_END&quot;</span>
  114. <span class="n">VALIDATION_EPOCH_END</span> <span class="o">=</span> <span class="s2">&quot;VALIDATION_EPOCH_END&quot;</span>
  115. <span class="n">VALIDATION_END_BEST_EPOCH</span> <span class="o">=</span> <span class="s2">&quot;VALIDATION_END_BEST_EPOCH&quot;</span>
  116. <span class="n">TEST_BATCH_END</span> <span class="o">=</span> <span class="s2">&quot;TEST_BATCH_END&quot;</span>
  117. <span class="n">TEST_END</span> <span class="o">=</span> <span class="s2">&quot;TEST_END&quot;</span>
  118. <span class="n">POST_TRAINING</span> <span class="o">=</span> <span class="s2">&quot;POST_TRAINING&quot;</span></div>
  119. <div class="viewcode-block" id="ContextSgMethods"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.ContextSgMethods">[docs]</a><span class="k">class</span> <span class="nc">ContextSgMethods</span><span class="p">:</span>
  120. <span class="sd">&quot;&quot;&quot;</span>
  121. <span class="sd"> Class for delegating SgModel&#39;s methods, so that only the relevant ones are (&quot;phase wise&quot;) are accessible.</span>
  122. <span class="sd"> &quot;&quot;&quot;</span>
  123. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">methods</span><span class="p">):</span>
  124. <span class="k">for</span> <span class="n">attr</span><span class="p">,</span> <span class="n">attr_val</span> <span class="ow">in</span> <span class="n">methods</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
  125. <span class="nb">setattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">attr</span><span class="p">,</span> <span class="n">attr_val</span><span class="p">)</span></div>
  126. <div class="viewcode-block" id="PhaseContext"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.PhaseContext">[docs]</a><span class="k">class</span> <span class="nc">PhaseContext</span><span class="p">:</span>
  127. <span class="sd">&quot;&quot;&quot;</span>
  128. <span class="sd"> Represents the input for phase callbacks, and is constantly updated after callback calls.</span>
  129. <span class="sd"> &quot;&quot;&quot;</span>
  130. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">epoch</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">batch_idx</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">optimizer</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">metrics_dict</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">inputs</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">preds</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
  131. <span class="n">target</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">metrics_compute_fn</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">loss_avg_meter</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">loss_log_items</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">criterion</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
  132. <span class="n">device</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">experiment_name</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">ckpt_dir</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">net</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">lr_warmup_epochs</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">sg_logger</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
  133. <span class="n">train_loader</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">valid_loader</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
  134. <span class="n">training_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">ddp_silent_mode</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">checkpoint_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">architecture</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
  135. <span class="n">arch_params</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">metric_idx_in_results_tuple</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
  136. <span class="n">metric_to_watch</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">valid_metrics</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">context_methods</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  137. <span class="bp">self</span><span class="o">.</span><span class="n">epoch</span> <span class="o">=</span> <span class="n">epoch</span>
  138. <span class="bp">self</span><span class="o">.</span><span class="n">batch_idx</span> <span class="o">=</span> <span class="n">batch_idx</span>
  139. <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span> <span class="o">=</span> <span class="n">optimizer</span>
  140. <span class="bp">self</span><span class="o">.</span><span class="n">inputs</span> <span class="o">=</span> <span class="n">inputs</span>
  141. <span class="bp">self</span><span class="o">.</span><span class="n">preds</span> <span class="o">=</span> <span class="n">preds</span>
  142. <span class="bp">self</span><span class="o">.</span><span class="n">target</span> <span class="o">=</span> <span class="n">target</span>
  143. <span class="bp">self</span><span class="o">.</span><span class="n">metrics_dict</span> <span class="o">=</span> <span class="n">metrics_dict</span>
  144. <span class="bp">self</span><span class="o">.</span><span class="n">metrics_compute_fn</span> <span class="o">=</span> <span class="n">metrics_compute_fn</span>
  145. <span class="bp">self</span><span class="o">.</span><span class="n">loss_avg_meter</span> <span class="o">=</span> <span class="n">loss_avg_meter</span>
  146. <span class="bp">self</span><span class="o">.</span><span class="n">loss_log_items</span> <span class="o">=</span> <span class="n">loss_log_items</span>
  147. <span class="bp">self</span><span class="o">.</span><span class="n">criterion</span> <span class="o">=</span> <span class="n">criterion</span>
  148. <span class="bp">self</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">device</span>
  149. <span class="bp">self</span><span class="o">.</span><span class="n">stop_training</span> <span class="o">=</span> <span class="kc">False</span>
  150. <span class="bp">self</span><span class="o">.</span><span class="n">experiment_name</span> <span class="o">=</span> <span class="n">experiment_name</span>
  151. <span class="bp">self</span><span class="o">.</span><span class="n">ckpt_dir</span> <span class="o">=</span> <span class="n">ckpt_dir</span>
  152. <span class="bp">self</span><span class="o">.</span><span class="n">net</span> <span class="o">=</span> <span class="n">net</span>
  153. <span class="bp">self</span><span class="o">.</span><span class="n">lr_warmup_epochs</span> <span class="o">=</span> <span class="n">lr_warmup_epochs</span>
  154. <span class="bp">self</span><span class="o">.</span><span class="n">sg_logger</span> <span class="o">=</span> <span class="n">sg_logger</span>
  155. <span class="bp">self</span><span class="o">.</span><span class="n">train_loader</span> <span class="o">=</span> <span class="n">train_loader</span>
  156. <span class="bp">self</span><span class="o">.</span><span class="n">valid_loader</span> <span class="o">=</span> <span class="n">valid_loader</span>
  157. <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span> <span class="o">=</span> <span class="n">training_params</span>
  158. <span class="bp">self</span><span class="o">.</span><span class="n">ddp_silent_mode</span> <span class="o">=</span> <span class="n">ddp_silent_mode</span>
  159. <span class="bp">self</span><span class="o">.</span><span class="n">checkpoint_params</span> <span class="o">=</span> <span class="n">checkpoint_params</span>
  160. <span class="bp">self</span><span class="o">.</span><span class="n">architecture</span> <span class="o">=</span> <span class="n">architecture</span>
  161. <span class="bp">self</span><span class="o">.</span><span class="n">arch_params</span> <span class="o">=</span> <span class="n">arch_params</span>
  162. <span class="bp">self</span><span class="o">.</span><span class="n">metric_idx_in_results_tuple</span> <span class="o">=</span> <span class="n">metric_idx_in_results_tuple</span>
  163. <span class="bp">self</span><span class="o">.</span><span class="n">metric_to_watch</span> <span class="o">=</span> <span class="n">metric_to_watch</span>
  164. <span class="bp">self</span><span class="o">.</span><span class="n">valid_metrics</span> <span class="o">=</span> <span class="n">valid_metrics</span>
  165. <span class="bp">self</span><span class="o">.</span><span class="n">context_methods</span> <span class="o">=</span> <span class="n">context_methods</span>
  166. <div class="viewcode-block" id="PhaseContext.update_context"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.PhaseContext.update_context">[docs]</a> <span class="k">def</span> <span class="nf">update_context</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  167. <span class="k">for</span> <span class="n">attr</span><span class="p">,</span> <span class="n">attr_val</span> <span class="ow">in</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
  168. <span class="nb">setattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">attr</span><span class="p">,</span> <span class="n">attr_val</span><span class="p">)</span></div></div>
  169. <div class="viewcode-block" id="PhaseCallback"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.PhaseCallback">[docs]</a><span class="k">class</span> <span class="nc">PhaseCallback</span><span class="p">:</span>
  170. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">phase</span><span class="p">:</span> <span class="n">Phase</span><span class="p">):</span>
  171. <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="o">=</span> <span class="n">phase</span>
  172. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  173. <span class="k">raise</span> <span class="ne">NotImplementedError</span>
  174. <span class="k">def</span> <span class="fm">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  175. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span></div>
  176. <div class="viewcode-block" id="ModelConversionCheckCallback"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.ModelConversionCheckCallback">[docs]</a><span class="k">class</span> <span class="nc">ModelConversionCheckCallback</span><span class="p">(</span><span class="n">PhaseCallback</span><span class="p">):</span>
  177. <span class="sd">&quot;&quot;&quot;</span>
  178. <span class="sd"> Pre-training callback that verifies model conversion to onnx given specified conversion parameters.</span>
  179. <span class="sd"> The model is converted, then inference is applied with onnx runtime.</span>
  180. <span class="sd"> Use this callback wit hthe same args as DeciPlatformCallback to prevent conversion fails at the end of training.</span>
  181. <span class="sd"> Attributes:</span>
  182. <span class="sd"> model_meta_data: (ModelMetadata) model&#39;s meta-data object.</span>
  183. <span class="sd"> The following parameters may be passed as kwargs in order to control the conversion to onnx:</span>
  184. <span class="sd"> :param opset_version (default=11)</span>
  185. <span class="sd"> :param do_constant_folding (default=True)</span>
  186. <span class="sd"> :param dynamic_axes (default=</span>
  187. <span class="sd"> {&#39;input&#39;: {0: &#39;batch_size&#39;},</span>
  188. <span class="sd"> # Variable length axes</span>
  189. <span class="sd"> &#39;output&#39;: {0: &#39;batch_size&#39;}}</span>
  190. <span class="sd"> )</span>
  191. <span class="sd"> :param input_names (default=[&quot;input&quot;])</span>
  192. <span class="sd"> :param output_names (default=[&quot;output&quot;])</span>
  193. <span class="sd"> :param rtol (default=1e-03)</span>
  194. <span class="sd"> :param atol (default=1e-05)</span>
  195. <span class="sd"> &quot;&quot;&quot;</span>
  196. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model_meta_data</span><span class="p">:</span> <span class="n">ModelMetadata</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  197. <span class="nb">super</span><span class="p">(</span><span class="n">ModelConversionCheckCallback</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">phase</span><span class="o">=</span><span class="n">Phase</span><span class="o">.</span><span class="n">PRE_TRAINING</span><span class="p">)</span>
  198. <span class="bp">self</span><span class="o">.</span><span class="n">model_meta_data</span> <span class="o">=</span> <span class="n">model_meta_data</span>
  199. <span class="bp">self</span><span class="o">.</span><span class="n">opset_version</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;opset_version&quot;</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span>
  200. <span class="bp">self</span><span class="o">.</span><span class="n">do_constant_folding</span> <span class="o">=</span> <span class="p">(</span>
  201. <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;do_constant_folding&quot;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span> <span class="k">if</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;do_constant_folding&quot;</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span> <span class="k">else</span> <span class="kc">True</span>
  202. <span class="p">)</span>
  203. <span class="bp">self</span><span class="o">.</span><span class="n">input_names</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;input_names&quot;</span><span class="p">)</span> <span class="ow">or</span> <span class="p">[</span><span class="s2">&quot;input&quot;</span><span class="p">]</span>
  204. <span class="bp">self</span><span class="o">.</span><span class="n">output_names</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;output_names&quot;</span><span class="p">)</span> <span class="ow">or</span> <span class="p">[</span><span class="s2">&quot;output&quot;</span><span class="p">]</span>
  205. <span class="bp">self</span><span class="o">.</span><span class="n">dynamic_axes</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;dynamic_axes&quot;</span><span class="p">)</span> <span class="ow">or</span> <span class="p">{</span><span class="s2">&quot;input&quot;</span><span class="p">:</span> <span class="p">{</span><span class="mi">0</span><span class="p">:</span> <span class="s2">&quot;batch_size&quot;</span><span class="p">},</span> <span class="s2">&quot;output&quot;</span><span class="p">:</span> <span class="p">{</span><span class="mi">0</span><span class="p">:</span> <span class="s2">&quot;batch_size&quot;</span><span class="p">}}</span>
  206. <span class="bp">self</span><span class="o">.</span><span class="n">rtol</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;rtol&quot;</span><span class="p">,</span> <span class="mf">1e-03</span><span class="p">)</span>
  207. <span class="bp">self</span><span class="o">.</span><span class="n">atol</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">&quot;atol&quot;</span><span class="p">,</span> <span class="mf">1e-05</span><span class="p">)</span>
  208. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">PhaseContext</span><span class="p">):</span>
  209. <span class="n">model</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">net</span><span class="o">.</span><span class="n">module</span><span class="p">)</span>
  210. <span class="n">model</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span>
  211. <span class="n">model</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span> <span class="c1"># Put model into eval mode</span>
  212. <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="s2">&quot;prep_model_for_conversion&quot;</span><span class="p">):</span>
  213. <span class="n">model</span><span class="o">.</span><span class="n">prep_model_for_conversion</span><span class="p">(</span><span class="n">input_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">model_meta_data</span><span class="o">.</span><span class="n">input_dimensions</span><span class="p">)</span>
  214. <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span>
  215. <span class="bp">self</span><span class="o">.</span><span class="n">model_meta_data</span><span class="o">.</span><span class="n">primary_batch_size</span><span class="p">,</span> <span class="o">*</span><span class="bp">self</span><span class="o">.</span><span class="n">model_meta_data</span><span class="o">.</span><span class="n">input_dimensions</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span>
  216. <span class="p">)</span>
  217. <span class="n">tmp_model_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">ckpt_dir</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_meta_data</span><span class="o">.</span><span class="n">name</span> <span class="o">+</span> <span class="s2">&quot;_tmp.onnx&quot;</span><span class="p">)</span>
  218. <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
  219. <span class="n">torch_out</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
  220. <span class="n">torch</span><span class="o">.</span><span class="n">onnx</span><span class="o">.</span><span class="n">export</span><span class="p">(</span>
  221. <span class="n">model</span><span class="p">,</span> <span class="c1"># Model being run</span>
  222. <span class="n">x</span><span class="p">,</span> <span class="c1"># Model input (or a tuple for multiple inputs)</span>
  223. <span class="n">tmp_model_path</span><span class="p">,</span> <span class="c1"># Where to save the model (can be a file or file-like object)</span>
  224. <span class="n">export_params</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="c1"># Store the trained parameter weights inside the model file</span>
  225. <span class="n">opset_version</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">opset_version</span><span class="p">,</span>
  226. <span class="n">do_constant_folding</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">do_constant_folding</span><span class="p">,</span>
  227. <span class="n">input_names</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">input_names</span><span class="p">,</span>
  228. <span class="n">output_names</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">output_names</span><span class="p">,</span>
  229. <span class="n">dynamic_axes</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dynamic_axes</span><span class="p">,</span>
  230. <span class="p">)</span>
  231. <span class="n">onnx_model</span> <span class="o">=</span> <span class="n">onnx</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">tmp_model_path</span><span class="p">)</span>
  232. <span class="n">onnx</span><span class="o">.</span><span class="n">checker</span><span class="o">.</span><span class="n">check_model</span><span class="p">(</span><span class="n">onnx_model</span><span class="p">)</span>
  233. <span class="n">ort_session</span> <span class="o">=</span> <span class="n">onnxruntime</span><span class="o">.</span><span class="n">InferenceSession</span><span class="p">(</span>
  234. <span class="n">tmp_model_path</span><span class="p">,</span> <span class="n">providers</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;CUDAExecutionProvider&quot;</span><span class="p">,</span> <span class="s2">&quot;CPUExecutionProvider&quot;</span><span class="p">]</span>
  235. <span class="p">)</span>
  236. <span class="c1"># compute ONNX Runtime output prediction</span>
  237. <span class="n">ort_inputs</span> <span class="o">=</span> <span class="p">{</span><span class="n">ort_session</span><span class="o">.</span><span class="n">get_inputs</span><span class="p">()[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">name</span><span class="p">:</span> <span class="n">x</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()}</span>
  238. <span class="n">ort_outs</span> <span class="o">=</span> <span class="n">ort_session</span><span class="o">.</span><span class="n">run</span><span class="p">(</span><span class="kc">None</span><span class="p">,</span> <span class="n">ort_inputs</span><span class="p">)</span>
  239. <span class="c1"># TODO: Ideally we don&#39;t want to check this but have the certainty of just calling torch_out.cpu()</span>
  240. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">torch_out</span><span class="p">,</span> <span class="n">List</span><span class="p">)</span> <span class="ow">or</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">torch_out</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">):</span>
  241. <span class="n">torch_out</span> <span class="o">=</span> <span class="n">torch_out</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
  242. <span class="c1"># compare ONNX Runtime and PyTorch results</span>
  243. <span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_allclose</span><span class="p">(</span><span class="n">torch_out</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">ort_outs</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">rtol</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">rtol</span><span class="p">,</span> <span class="n">atol</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">atol</span><span class="p">)</span>
  244. <span class="n">os</span><span class="o">.</span><span class="n">remove</span><span class="p">(</span><span class="n">tmp_model_path</span><span class="p">)</span>
  245. <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">&quot;Exported model has been tested with ONNXRuntime, and the result looks good!&quot;</span><span class="p">)</span></div>
  246. <div class="viewcode-block" id="DeciLabUploadCallback"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.DeciLabUploadCallback">[docs]</a><span class="k">class</span> <span class="nc">DeciLabUploadCallback</span><span class="p">(</span><span class="n">PhaseCallback</span><span class="p">):</span>
  247. <span class="sd">&quot;&quot;&quot;</span>
  248. <span class="sd"> Post-training callback for uploading and optimizing a model.</span>
  249. <span class="sd"> Attributes:</span>
  250. <span class="sd"> email: (str) username for Deci platform.</span>
  251. <span class="sd"> model_meta_data: (ModelMetadata) model&#39;s meta-data object.</span>
  252. <span class="sd"> optimization_request_form: (dict) optimization request form object.</span>
  253. <span class="sd"> password: (str) default=None, should only be used for testing.</span>
  254. <span class="sd"> ckpt_name: (str) default=&quot;ckpt_best&quot; refers to the filename of the checkpoint, inside the checkpoint directory.</span>
  255. <span class="sd"> The following parameters may be passed as kwargs in order to control the conversion to onnx:</span>
  256. <span class="sd"> :param opset_version</span>
  257. <span class="sd"> :param do_constant_folding</span>
  258. <span class="sd"> :param dynamic_axes</span>
  259. <span class="sd"> :param input_names</span>
  260. <span class="sd"> :param output_names</span>
  261. <span class="sd"> &quot;&quot;&quot;</span>
  262. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model_meta_data</span><span class="p">,</span> <span class="n">optimization_request_form</span><span class="p">,</span> <span class="n">auth_token</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">ckpt_name</span><span class="o">=</span><span class="s2">&quot;ckpt_best.pth&quot;</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  263. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">phase</span><span class="o">=</span><span class="n">Phase</span><span class="o">.</span><span class="n">POST_TRAINING</span><span class="p">)</span>
  264. <span class="k">if</span> <span class="n">_imported_deci_lab_failure</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  265. <span class="k">raise</span> <span class="n">_imported_deci_lab_failure</span>
  266. <span class="bp">self</span><span class="o">.</span><span class="n">model_meta_data</span> <span class="o">=</span> <span class="n">model_meta_data</span>
  267. <span class="bp">self</span><span class="o">.</span><span class="n">optimization_request_form</span> <span class="o">=</span> <span class="n">optimization_request_form</span>
  268. <span class="bp">self</span><span class="o">.</span><span class="n">conversion_kwargs</span> <span class="o">=</span> <span class="n">kwargs</span>
  269. <span class="bp">self</span><span class="o">.</span><span class="n">ckpt_name</span> <span class="o">=</span> <span class="n">ckpt_name</span>
  270. <span class="bp">self</span><span class="o">.</span><span class="n">platform_client</span> <span class="o">=</span> <span class="n">DeciPlatformClient</span><span class="p">(</span><span class="s2">&quot;api.deci.ai&quot;</span><span class="p">,</span> <span class="mi">443</span><span class="p">,</span> <span class="n">https</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  271. <span class="bp">self</span><span class="o">.</span><span class="n">platform_client</span><span class="o">.</span><span class="n">login</span><span class="p">(</span><span class="n">token</span><span class="o">=</span><span class="n">auth_token</span><span class="p">)</span>
  272. <div class="viewcode-block" id="DeciLabUploadCallback.log_optimization_failed"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.DeciLabUploadCallback.log_optimization_failed">[docs]</a> <span class="nd">@staticmethod</span>
  273. <span class="k">def</span> <span class="nf">log_optimization_failed</span><span class="p">():</span>
  274. <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">&quot;We couldn&#39;t finish your model optimization. Visit https://console.deci.ai for details&quot;</span><span class="p">)</span></div>
  275. <div class="viewcode-block" id="DeciLabUploadCallback.upload_model"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.DeciLabUploadCallback.upload_model">[docs]</a> <span class="k">def</span> <span class="nf">upload_model</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model</span><span class="p">):</span>
  276. <span class="sd">&quot;&quot;&quot;</span>
  277. <span class="sd"> This function will upload the trained model to the Deci Lab</span>
  278. <span class="sd"> Args:</span>
  279. <span class="sd"> model: The resulting model from the training process</span>
  280. <span class="sd"> &quot;&quot;&quot;</span>
  281. <span class="bp">self</span><span class="o">.</span><span class="n">platform_client</span><span class="o">.</span><span class="n">add_model</span><span class="p">(</span>
  282. <span class="n">add_model_request</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">model_meta_data</span><span class="p">,</span>
  283. <span class="n">optimization_request</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">optimization_request_form</span><span class="p">,</span>
  284. <span class="n">local_loaded_model</span><span class="o">=</span><span class="n">model</span><span class="p">,</span>
  285. <span class="p">)</span></div>
  286. <div class="viewcode-block" id="DeciLabUploadCallback.get_optimization_status"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.DeciLabUploadCallback.get_optimization_status">[docs]</a> <span class="k">def</span> <span class="nf">get_optimization_status</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">optimized_model_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
  287. <span class="sd">&quot;&quot;&quot;</span>
  288. <span class="sd"> This function will do fetch the optimized version of the trained model and check on its benchmark status.</span>
  289. <span class="sd"> The status will be checked against the server every 30 seconds and the process will timeout after 30 minutes</span>
  290. <span class="sd"> or log about the successful optimization - whichever happens first.</span>
  291. <span class="sd"> Args:</span>
  292. <span class="sd"> optimized_model_name (str): Optimized model name</span>
  293. <span class="sd"> Returns:</span>
  294. <span class="sd"> bool: whether or not the optimized model has been benchmarked</span>
  295. <span class="sd"> &quot;&quot;&quot;</span>
  296. <span class="k">def</span> <span class="nf">handler</span><span class="p">(</span><span class="n">_signum</span><span class="p">,</span> <span class="n">_frame</span><span class="p">):</span>
  297. <span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="s2">&quot;Process timed out. Visit https://console.deci.ai for details&quot;</span><span class="p">)</span>
  298. <span class="k">return</span> <span class="kc">False</span>
  299. <span class="n">signal</span><span class="o">.</span><span class="n">signal</span><span class="p">(</span><span class="n">signal</span><span class="o">.</span><span class="n">SIGALRM</span><span class="p">,</span> <span class="n">handler</span><span class="p">)</span>
  300. <span class="n">signal</span><span class="o">.</span><span class="n">alarm</span><span class="p">(</span><span class="mi">1800</span><span class="p">)</span>
  301. <span class="n">finished</span> <span class="o">=</span> <span class="kc">False</span>
  302. <span class="k">while</span> <span class="ow">not</span> <span class="n">finished</span><span class="p">:</span>
  303. <span class="n">optimized_model</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">platform_client</span><span class="o">.</span><span class="n">get_model_by_name</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">optimized_model_name</span><span class="p">)</span><span class="o">.</span><span class="n">data</span>
  304. <span class="k">if</span> <span class="n">optimized_model</span><span class="o">.</span><span class="n">benchmark_state</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span><span class="n">ModelBenchmarkState</span><span class="o">.</span><span class="n">IN_PROGRESS</span><span class="p">,</span> <span class="n">ModelBenchmarkState</span><span class="o">.</span><span class="n">PENDING</span><span class="p">]:</span>
  305. <span class="n">finished</span> <span class="o">=</span> <span class="kc">True</span>
  306. <span class="k">else</span><span class="p">:</span>
  307. <span class="n">time</span><span class="o">.</span><span class="n">sleep</span><span class="p">(</span><span class="mi">30</span><span class="p">)</span>
  308. <span class="n">signal</span><span class="o">.</span><span class="n">alarm</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
  309. <span class="k">return</span> <span class="kc">True</span></div>
  310. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">PhaseContext</span><span class="p">):</span>
  311. <span class="sd">&quot;&quot;&quot;</span>
  312. <span class="sd"> This function will attempt to upload the trained model and schedule an optimization for it.</span>
  313. <span class="sd"> Args:</span>
  314. <span class="sd"> context (PhaseContext): Training phase context</span>
  315. <span class="sd"> Returns:</span>
  316. <span class="sd"> bool: whether or not the optimized model has been benchmarked</span>
  317. <span class="sd"> &quot;&quot;&quot;</span>
  318. <span class="k">try</span><span class="p">:</span>
  319. <span class="n">model</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">net</span><span class="p">)</span>
  320. <span class="n">model_state_dict_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">ckpt_dir</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">ckpt_name</span><span class="p">)</span>
  321. <span class="n">model_state_dict</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">model_state_dict_path</span><span class="p">)[</span><span class="s2">&quot;net&quot;</span><span class="p">]</span>
  322. <span class="n">model</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">state_dict</span><span class="o">=</span><span class="n">model_state_dict</span><span class="p">)</span>
  323. <span class="n">model</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">module</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span>
  324. <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="s2">&quot;prep_model_for_conversion&quot;</span><span class="p">):</span>
  325. <span class="n">model</span><span class="o">.</span><span class="n">prep_model_for_conversion</span><span class="p">(</span><span class="n">input_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">model_meta_data</span><span class="o">.</span><span class="n">input_dimensions</span><span class="p">)</span>
  326. <span class="bp">self</span><span class="o">.</span><span class="n">upload_model</span><span class="p">(</span><span class="n">model</span><span class="o">=</span><span class="n">model</span><span class="p">)</span>
  327. <span class="n">model_name</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_meta_data</span><span class="o">.</span><span class="n">name</span>
  328. <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Successfully added </span><span class="si">{</span><span class="n">model_name</span><span class="si">}</span><span class="s2"> to the model repository&quot;</span><span class="p">)</span>
  329. <span class="n">optimized_model_name</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">model_name</span><span class="si">}</span><span class="s2">_1_1&quot;</span>
  330. <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">&quot;We&#39;ll wait for the scheduled optimization to finish. Please don&#39;t close this window&quot;</span><span class="p">)</span>
  331. <span class="n">success</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_optimization_status</span><span class="p">(</span><span class="n">optimized_model_name</span><span class="o">=</span><span class="n">optimized_model_name</span><span class="p">)</span>
  332. <span class="k">if</span> <span class="n">success</span><span class="p">:</span>
  333. <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">&quot;Successfully finished your model optimization. Visit https://console.deci.ai for details&quot;</span><span class="p">)</span>
  334. <span class="k">else</span><span class="p">:</span>
  335. <span class="n">DeciLabUploadCallback</span><span class="o">.</span><span class="n">log_optimization_failed</span><span class="p">()</span>
  336. <span class="k">except</span> <span class="ne">Exception</span> <span class="k">as</span> <span class="n">ex</span><span class="p">:</span>
  337. <span class="n">DeciLabUploadCallback</span><span class="o">.</span><span class="n">log_optimization_failed</span><span class="p">()</span>
  338. <span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="n">ex</span><span class="p">)</span></div>
  339. <div class="viewcode-block" id="LRCallbackBase"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.LRCallbackBase">[docs]</a><span class="k">class</span> <span class="nc">LRCallbackBase</span><span class="p">(</span><span class="n">PhaseCallback</span><span class="p">):</span>
  340. <span class="sd">&quot;&quot;&quot;</span>
  341. <span class="sd"> Base class for hard coded learning rate scheduling regimes, implemented as callbacks.</span>
  342. <span class="sd"> &quot;&quot;&quot;</span>
  343. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">phase</span><span class="p">,</span> <span class="n">initial_lr</span><span class="p">,</span> <span class="n">update_param_groups</span><span class="p">,</span> <span class="n">train_loader_len</span><span class="p">,</span> <span class="n">net</span><span class="p">,</span> <span class="n">training_params</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  344. <span class="nb">super</span><span class="p">(</span><span class="n">LRCallbackBase</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">phase</span><span class="p">)</span>
  345. <span class="bp">self</span><span class="o">.</span><span class="n">initial_lr</span> <span class="o">=</span> <span class="n">initial_lr</span>
  346. <span class="bp">self</span><span class="o">.</span><span class="n">lr</span> <span class="o">=</span> <span class="n">initial_lr</span>
  347. <span class="bp">self</span><span class="o">.</span><span class="n">update_param_groups</span> <span class="o">=</span> <span class="n">update_param_groups</span>
  348. <span class="bp">self</span><span class="o">.</span><span class="n">train_loader_len</span> <span class="o">=</span> <span class="n">train_loader_len</span>
  349. <span class="bp">self</span><span class="o">.</span><span class="n">net</span> <span class="o">=</span> <span class="n">net</span>
  350. <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span> <span class="o">=</span> <span class="n">training_params</span>
  351. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">PhaseContext</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  352. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_lr_scheduling_enabled</span><span class="p">(</span><span class="n">context</span><span class="p">):</span>
  353. <span class="bp">self</span><span class="o">.</span><span class="n">perform_scheduling</span><span class="p">(</span><span class="n">context</span><span class="p">)</span>
  354. <div class="viewcode-block" id="LRCallbackBase.is_lr_scheduling_enabled"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.LRCallbackBase.is_lr_scheduling_enabled">[docs]</a> <span class="k">def</span> <span class="nf">is_lr_scheduling_enabled</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">PhaseContext</span><span class="p">):</span>
  355. <span class="sd">&quot;&quot;&quot;</span>
  356. <span class="sd"> Predicate that controls whether to perform lr scheduling based on values in context.</span>
  357. <span class="sd"> @param context: PhaseContext: current phase&#39;s context.</span>
  358. <span class="sd"> @return: bool, whether to apply lr scheduling or not.</span>
  359. <span class="sd"> &quot;&quot;&quot;</span>
  360. <span class="k">raise</span> <span class="ne">NotImplementedError</span></div>
  361. <div class="viewcode-block" id="LRCallbackBase.perform_scheduling"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.LRCallbackBase.perform_scheduling">[docs]</a> <span class="k">def</span> <span class="nf">perform_scheduling</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">PhaseContext</span><span class="p">):</span>
  362. <span class="sd">&quot;&quot;&quot;</span>
  363. <span class="sd"> Performs lr scheduling based on values in context.</span>
  364. <span class="sd"> @param context: PhaseContext: current phase&#39;s context.</span>
  365. <span class="sd"> &quot;&quot;&quot;</span>
  366. <span class="k">raise</span> <span class="ne">NotImplementedError</span></div>
  367. <div class="viewcode-block" id="LRCallbackBase.update_lr"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.LRCallbackBase.update_lr">[docs]</a> <span class="k">def</span> <span class="nf">update_lr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">epoch</span><span class="p">,</span> <span class="n">batch_idx</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  368. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">update_param_groups</span><span class="p">:</span>
  369. <span class="n">param_groups</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">net</span><span class="o">.</span><span class="n">module</span><span class="o">.</span><span class="n">update_param_groups</span><span class="p">(</span>
  370. <span class="n">optimizer</span><span class="o">.</span><span class="n">param_groups</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">lr</span><span class="p">,</span> <span class="n">epoch</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_loader_len</span>
  371. <span class="p">)</span>
  372. <span class="n">optimizer</span><span class="o">.</span><span class="n">param_groups</span> <span class="o">=</span> <span class="n">param_groups</span>
  373. <span class="k">else</span><span class="p">:</span>
  374. <span class="c1"># UPDATE THE OPTIMIZERS PARAMETER</span>
  375. <span class="k">for</span> <span class="n">param_group</span> <span class="ow">in</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">param_groups</span><span class="p">:</span>
  376. <span class="n">param_group</span><span class="p">[</span><span class="s2">&quot;lr&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">lr</span></div></div>
  377. <div class="viewcode-block" id="WarmupLRCallback"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.WarmupLRCallback">[docs]</a><span class="k">class</span> <span class="nc">WarmupLRCallback</span><span class="p">(</span><span class="n">LRCallbackBase</span><span class="p">):</span>
  378. <span class="sd">&quot;&quot;&quot;</span>
  379. <span class="sd"> LR scheduling callback for linear step warmup.</span>
  380. <span class="sd"> LR climbs from warmup_initial_lr with even steps to initial lr. When warmup_initial_lr is None- LR climb starts from</span>
  381. <span class="sd"> initial_lr/(1+warmup_epochs).</span>
  382. <span class="sd"> &quot;&quot;&quot;</span>
  383. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  384. <span class="nb">super</span><span class="p">(</span><span class="n">WarmupLRCallback</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">Phase</span><span class="o">.</span><span class="n">TRAIN_EPOCH_START</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
  385. <span class="bp">self</span><span class="o">.</span><span class="n">warmup_initial_lr</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">warmup_initial_lr</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">initial_lr</span> <span class="o">/</span> <span class="p">(</span>
  386. <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_warmup_epochs</span> <span class="o">+</span> <span class="mi">1</span>
  387. <span class="p">)</span>
  388. <span class="bp">self</span><span class="o">.</span><span class="n">warmup_step_size</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">initial_lr</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">warmup_initial_lr</span><span class="p">)</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_warmup_epochs</span>
  389. <div class="viewcode-block" id="WarmupLRCallback.perform_scheduling"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.WarmupLRCallback.perform_scheduling">[docs]</a> <span class="k">def</span> <span class="nf">perform_scheduling</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">):</span>
  390. <span class="bp">self</span><span class="o">.</span><span class="n">lr</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">warmup_initial_lr</span> <span class="o">+</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">warmup_step_size</span>
  391. <span class="bp">self</span><span class="o">.</span><span class="n">update_lr</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span></div>
  392. <div class="viewcode-block" id="WarmupLRCallback.is_lr_scheduling_enabled"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.WarmupLRCallback.is_lr_scheduling_enabled">[docs]</a> <span class="k">def</span> <span class="nf">is_lr_scheduling_enabled</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">):</span>
  393. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_warmup_epochs</span> <span class="o">&gt;=</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span></div></div>
  394. <div class="viewcode-block" id="StepLRCallback"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.StepLRCallback">[docs]</a><span class="k">class</span> <span class="nc">StepLRCallback</span><span class="p">(</span><span class="n">LRCallbackBase</span><span class="p">):</span>
  395. <span class="sd">&quot;&quot;&quot;</span>
  396. <span class="sd"> Hard coded step learning rate scheduling (i.e at specific milestones).</span>
  397. <span class="sd"> &quot;&quot;&quot;</span>
  398. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">lr_updates</span><span class="p">,</span> <span class="n">lr_decay_factor</span><span class="p">,</span> <span class="n">step_lr_update_freq</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  399. <span class="nb">super</span><span class="p">(</span><span class="n">StepLRCallback</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">Phase</span><span class="o">.</span><span class="n">TRAIN_EPOCH_END</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
  400. <span class="k">if</span> <span class="n">step_lr_update_freq</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">lr_updates</span><span class="p">):</span>
  401. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
  402. <span class="s2">&quot;Only one of [lr_updates, step_lr_update_freq] should be passed to StepLRCallback constructor&quot;</span>
  403. <span class="p">)</span>
  404. <span class="k">if</span> <span class="n">step_lr_update_freq</span><span class="p">:</span>
  405. <span class="n">max_epochs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">max_epochs</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_cooldown_epochs</span>
  406. <span class="n">warmup_epochs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_warmup_epochs</span>
  407. <span class="n">lr_updates</span> <span class="o">=</span> <span class="p">[</span>
  408. <span class="nb">int</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">step_lr_update_freq</span> <span class="o">*</span> <span class="n">x</span><span class="p">))</span>
  409. <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">max_epochs</span><span class="p">)</span>
  410. <span class="k">if</span> <span class="n">warmup_epochs</span> <span class="o">&lt;=</span> <span class="nb">int</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">step_lr_update_freq</span> <span class="o">*</span> <span class="n">x</span><span class="p">))</span> <span class="o">&lt;</span> <span class="n">max_epochs</span>
  411. <span class="p">]</span>
  412. <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_cooldown_epochs</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
  413. <span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
  414. <span class="s2">&quot;Specific lr_updates were passed along with cooldown_epochs &gt; 0,&quot;</span> <span class="s2">&quot; cooldown will have no effect.&quot;</span>
  415. <span class="p">)</span>
  416. <span class="bp">self</span><span class="o">.</span><span class="n">lr_updates</span> <span class="o">=</span> <span class="n">lr_updates</span>
  417. <span class="bp">self</span><span class="o">.</span><span class="n">lr_decay_factor</span> <span class="o">=</span> <span class="n">lr_decay_factor</span>
  418. <div class="viewcode-block" id="StepLRCallback.perform_scheduling"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.StepLRCallback.perform_scheduling">[docs]</a> <span class="k">def</span> <span class="nf">perform_scheduling</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">):</span>
  419. <span class="n">num_updates_passed</span> <span class="o">=</span> <span class="p">[</span><span class="n">x</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">lr_updates</span> <span class="k">if</span> <span class="n">x</span> <span class="o">&lt;=</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span><span class="p">]</span>
  420. <span class="bp">self</span><span class="o">.</span><span class="n">lr</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">initial_lr</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">lr_decay_factor</span> <span class="o">**</span> <span class="nb">len</span><span class="p">(</span><span class="n">num_updates_passed</span><span class="p">)</span>
  421. <span class="bp">self</span><span class="o">.</span><span class="n">update_lr</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span></div>
  422. <div class="viewcode-block" id="StepLRCallback.is_lr_scheduling_enabled"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.StepLRCallback.is_lr_scheduling_enabled">[docs]</a> <span class="k">def</span> <span class="nf">is_lr_scheduling_enabled</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">):</span>
  423. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_warmup_epochs</span> <span class="o">&lt;=</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span></div></div>
  424. <div class="viewcode-block" id="ExponentialLRCallback"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.ExponentialLRCallback">[docs]</a><span class="k">class</span> <span class="nc">ExponentialLRCallback</span><span class="p">(</span><span class="n">LRCallbackBase</span><span class="p">):</span>
  425. <span class="sd">&quot;&quot;&quot;</span>
  426. <span class="sd"> Exponential decay learning rate scheduling. Decays the learning rate by `lr_decay_factor` every epoch.</span>
  427. <span class="sd"> &quot;&quot;&quot;</span>
  428. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">lr_decay_factor</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  429. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">phase</span><span class="o">=</span><span class="n">Phase</span><span class="o">.</span><span class="n">TRAIN_BATCH_STEP</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
  430. <span class="bp">self</span><span class="o">.</span><span class="n">lr_decay_factor</span> <span class="o">=</span> <span class="n">lr_decay_factor</span>
  431. <div class="viewcode-block" id="ExponentialLRCallback.perform_scheduling"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.ExponentialLRCallback.perform_scheduling">[docs]</a> <span class="k">def</span> <span class="nf">perform_scheduling</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">):</span>
  432. <span class="n">effective_epoch</span> <span class="o">=</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_warmup_epochs</span>
  433. <span class="n">current_iter</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_loader_len</span> <span class="o">*</span> <span class="n">effective_epoch</span> <span class="o">+</span> <span class="n">context</span><span class="o">.</span><span class="n">batch_idx</span>
  434. <span class="bp">self</span><span class="o">.</span><span class="n">lr</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">initial_lr</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">lr_decay_factor</span> <span class="o">**</span> <span class="p">(</span><span class="n">current_iter</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_loader_len</span><span class="p">)</span>
  435. <span class="bp">self</span><span class="o">.</span><span class="n">update_lr</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span><span class="p">,</span> <span class="n">context</span><span class="o">.</span><span class="n">batch_idx</span><span class="p">)</span></div>
  436. <div class="viewcode-block" id="ExponentialLRCallback.is_lr_scheduling_enabled"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.ExponentialLRCallback.is_lr_scheduling_enabled">[docs]</a> <span class="k">def</span> <span class="nf">is_lr_scheduling_enabled</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">):</span>
  437. <span class="n">post_warmup_epochs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">max_epochs</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_cooldown_epochs</span>
  438. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_warmup_epochs</span> <span class="o">&lt;=</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span> <span class="o">&lt;</span> <span class="n">post_warmup_epochs</span></div></div>
  439. <div class="viewcode-block" id="PolyLRCallback"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.PolyLRCallback">[docs]</a><span class="k">class</span> <span class="nc">PolyLRCallback</span><span class="p">(</span><span class="n">LRCallbackBase</span><span class="p">):</span>
  440. <span class="sd">&quot;&quot;&quot;</span>
  441. <span class="sd"> Hard coded polynomial decay learning rate scheduling (i.e at specific milestones).</span>
  442. <span class="sd"> &quot;&quot;&quot;</span>
  443. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">max_epochs</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  444. <span class="nb">super</span><span class="p">(</span><span class="n">PolyLRCallback</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">Phase</span><span class="o">.</span><span class="n">TRAIN_BATCH_STEP</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
  445. <span class="bp">self</span><span class="o">.</span><span class="n">max_epochs</span> <span class="o">=</span> <span class="n">max_epochs</span>
  446. <div class="viewcode-block" id="PolyLRCallback.perform_scheduling"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.PolyLRCallback.perform_scheduling">[docs]</a> <span class="k">def</span> <span class="nf">perform_scheduling</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">):</span>
  447. <span class="n">effective_epoch</span> <span class="o">=</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_warmup_epochs</span>
  448. <span class="n">effective_max_epochs</span> <span class="o">=</span> <span class="p">(</span>
  449. <span class="bp">self</span><span class="o">.</span><span class="n">max_epochs</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_warmup_epochs</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_cooldown_epochs</span>
  450. <span class="p">)</span>
  451. <span class="n">current_iter</span> <span class="o">=</span> <span class="p">(</span>
  452. <span class="bp">self</span><span class="o">.</span><span class="n">train_loader_len</span> <span class="o">*</span> <span class="n">effective_epoch</span> <span class="o">+</span> <span class="n">context</span><span class="o">.</span><span class="n">batch_idx</span>
  453. <span class="p">)</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">batch_accumulate</span>
  454. <span class="n">max_iter</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_loader_len</span> <span class="o">*</span> <span class="n">effective_max_epochs</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">batch_accumulate</span>
  455. <span class="bp">self</span><span class="o">.</span><span class="n">lr</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">initial_lr</span> <span class="o">*</span> <span class="nb">pow</span><span class="p">((</span><span class="mf">1.0</span> <span class="o">-</span> <span class="p">(</span><span class="n">current_iter</span> <span class="o">/</span> <span class="n">max_iter</span><span class="p">)),</span> <span class="mf">0.9</span><span class="p">)</span>
  456. <span class="bp">self</span><span class="o">.</span><span class="n">update_lr</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span><span class="p">,</span> <span class="n">context</span><span class="o">.</span><span class="n">batch_idx</span><span class="p">)</span></div>
  457. <div class="viewcode-block" id="PolyLRCallback.is_lr_scheduling_enabled"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.PolyLRCallback.is_lr_scheduling_enabled">[docs]</a> <span class="k">def</span> <span class="nf">is_lr_scheduling_enabled</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">):</span>
  458. <span class="n">post_warmup_epochs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">max_epochs</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_cooldown_epochs</span>
  459. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_warmup_epochs</span> <span class="o">&lt;=</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span> <span class="o">&lt;</span> <span class="n">post_warmup_epochs</span></div></div>
  460. <div class="viewcode-block" id="CosineLRCallback"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.CosineLRCallback">[docs]</a><span class="k">class</span> <span class="nc">CosineLRCallback</span><span class="p">(</span><span class="n">LRCallbackBase</span><span class="p">):</span>
  461. <span class="sd">&quot;&quot;&quot;</span>
  462. <span class="sd"> Hard coded step Cosine anealing learning rate scheduling.</span>
  463. <span class="sd"> &quot;&quot;&quot;</span>
  464. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">max_epochs</span><span class="p">,</span> <span class="n">cosine_final_lr_ratio</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  465. <span class="nb">super</span><span class="p">(</span><span class="n">CosineLRCallback</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">Phase</span><span class="o">.</span><span class="n">TRAIN_BATCH_STEP</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
  466. <span class="bp">self</span><span class="o">.</span><span class="n">max_epochs</span> <span class="o">=</span> <span class="n">max_epochs</span>
  467. <span class="bp">self</span><span class="o">.</span><span class="n">cosine_final_lr_ratio</span> <span class="o">=</span> <span class="n">cosine_final_lr_ratio</span>
  468. <div class="viewcode-block" id="CosineLRCallback.perform_scheduling"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.CosineLRCallback.perform_scheduling">[docs]</a> <span class="k">def</span> <span class="nf">perform_scheduling</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">):</span>
  469. <span class="n">effective_epoch</span> <span class="o">=</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_warmup_epochs</span>
  470. <span class="n">effective_max_epochs</span> <span class="o">=</span> <span class="p">(</span>
  471. <span class="bp">self</span><span class="o">.</span><span class="n">max_epochs</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_warmup_epochs</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_cooldown_epochs</span>
  472. <span class="p">)</span>
  473. <span class="n">current_iter</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_loader_len</span> <span class="o">*</span> <span class="n">effective_epoch</span> <span class="o">+</span> <span class="n">context</span><span class="o">.</span><span class="n">batch_idx</span>
  474. <span class="n">max_iter</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_loader_len</span> <span class="o">*</span> <span class="n">effective_max_epochs</span>
  475. <span class="n">lr</span> <span class="o">=</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">initial_lr</span> <span class="o">*</span> <span class="p">(</span><span class="mf">1.0</span> <span class="o">+</span> <span class="n">math</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">current_iter</span> <span class="o">/</span> <span class="p">(</span><span class="n">max_iter</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">math</span><span class="o">.</span><span class="n">pi</span><span class="p">))</span>
  476. <span class="c1"># the cosine starts from initial_lr and reaches initial_lr * cosine_final_lr_ratio in last epoch</span>
  477. <span class="bp">self</span><span class="o">.</span><span class="n">lr</span> <span class="o">=</span> <span class="n">lr</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">cosine_final_lr_ratio</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">initial_lr</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">cosine_final_lr_ratio</span><span class="p">)</span>
  478. <span class="bp">self</span><span class="o">.</span><span class="n">update_lr</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span><span class="p">,</span> <span class="n">context</span><span class="o">.</span><span class="n">batch_idx</span><span class="p">)</span></div>
  479. <div class="viewcode-block" id="CosineLRCallback.is_lr_scheduling_enabled"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.CosineLRCallback.is_lr_scheduling_enabled">[docs]</a> <span class="k">def</span> <span class="nf">is_lr_scheduling_enabled</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">):</span>
  480. <span class="n">post_warmup_epochs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">max_epochs</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_cooldown_epochs</span>
  481. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_warmup_epochs</span> <span class="o">&lt;=</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span> <span class="o">&lt;</span> <span class="n">post_warmup_epochs</span></div></div>
  482. <div class="viewcode-block" id="FunctionLRCallback"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.FunctionLRCallback">[docs]</a><span class="k">class</span> <span class="nc">FunctionLRCallback</span><span class="p">(</span><span class="n">LRCallbackBase</span><span class="p">):</span>
  483. <span class="sd">&quot;&quot;&quot;</span>
  484. <span class="sd"> Hard coded rate scheduling for user defined lr scheduling function.</span>
  485. <span class="sd"> &quot;&quot;&quot;</span>
  486. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">max_epochs</span><span class="p">,</span> <span class="n">lr_schedule_function</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  487. <span class="nb">super</span><span class="p">(</span><span class="n">FunctionLRCallback</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">Phase</span><span class="o">.</span><span class="n">TRAIN_BATCH_STEP</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
  488. <span class="k">assert</span> <span class="n">callable</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">lr_schedule_function</span><span class="p">),</span> <span class="s2">&quot;self.lr_function must be callable&quot;</span>
  489. <span class="bp">self</span><span class="o">.</span><span class="n">lr_schedule_function</span> <span class="o">=</span> <span class="n">lr_schedule_function</span>
  490. <span class="bp">self</span><span class="o">.</span><span class="n">max_epochs</span> <span class="o">=</span> <span class="n">max_epochs</span>
  491. <div class="viewcode-block" id="FunctionLRCallback.is_lr_scheduling_enabled"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.FunctionLRCallback.is_lr_scheduling_enabled">[docs]</a> <span class="k">def</span> <span class="nf">is_lr_scheduling_enabled</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">):</span>
  492. <span class="n">post_warmup_epochs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">max_epochs</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_cooldown_epochs</span>
  493. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_warmup_epochs</span> <span class="o">&lt;=</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span> <span class="o">&lt;</span> <span class="n">post_warmup_epochs</span></div>
  494. <div class="viewcode-block" id="FunctionLRCallback.perform_scheduling"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.FunctionLRCallback.perform_scheduling">[docs]</a> <span class="k">def</span> <span class="nf">perform_scheduling</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">):</span>
  495. <span class="n">effective_epoch</span> <span class="o">=</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_warmup_epochs</span>
  496. <span class="n">effective_max_epochs</span> <span class="o">=</span> <span class="p">(</span>
  497. <span class="bp">self</span><span class="o">.</span><span class="n">max_epochs</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_warmup_epochs</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">training_params</span><span class="o">.</span><span class="n">lr_cooldown_epochs</span>
  498. <span class="p">)</span>
  499. <span class="bp">self</span><span class="o">.</span><span class="n">lr</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">lr_schedule_function</span><span class="p">(</span>
  500. <span class="n">initial_lr</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">initial_lr</span><span class="p">,</span>
  501. <span class="n">epoch</span><span class="o">=</span><span class="n">effective_epoch</span><span class="p">,</span>
  502. <span class="nb">iter</span><span class="o">=</span><span class="n">context</span><span class="o">.</span><span class="n">batch_idx</span><span class="p">,</span>
  503. <span class="n">max_epoch</span><span class="o">=</span><span class="n">effective_max_epochs</span><span class="p">,</span>
  504. <span class="n">iters_per_epoch</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">train_loader_len</span><span class="p">,</span>
  505. <span class="p">)</span>
  506. <span class="bp">self</span><span class="o">.</span><span class="n">update_lr</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span><span class="p">,</span> <span class="n">context</span><span class="o">.</span><span class="n">batch_idx</span><span class="p">)</span></div></div>
  507. <div class="viewcode-block" id="IllegalLRSchedulerMetric"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.IllegalLRSchedulerMetric">[docs]</a><span class="k">class</span> <span class="nc">IllegalLRSchedulerMetric</span><span class="p">(</span><span class="ne">Exception</span><span class="p">):</span>
  508. <span class="sd">&quot;&quot;&quot;Exception raised illegal combination of training parameters.</span>
  509. <span class="sd"> Attributes:</span>
  510. <span class="sd"> message -- explanation of the error</span>
  511. <span class="sd"> &quot;&quot;&quot;</span>
  512. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">metric_name</span><span class="p">,</span> <span class="n">metrics_dict</span><span class="p">):</span>
  513. <span class="bp">self</span><span class="o">.</span><span class="n">message</span> <span class="o">=</span> <span class="p">(</span>
  514. <span class="s2">&quot;Illegal metric name: &quot;</span> <span class="o">+</span> <span class="n">metric_name</span> <span class="o">+</span> <span class="s2">&quot;. Expected one of metics_dics keys: &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">metrics_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span>
  515. <span class="p">)</span>
  516. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">message</span><span class="p">)</span></div>
  517. <div class="viewcode-block" id="LRSchedulerCallback"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.LRSchedulerCallback">[docs]</a><span class="k">class</span> <span class="nc">LRSchedulerCallback</span><span class="p">(</span><span class="n">PhaseCallback</span><span class="p">):</span>
  518. <span class="sd">&quot;&quot;&quot;</span>
  519. <span class="sd"> Learning rate scheduler callback.</span>
  520. <span class="sd"> Attributes:</span>
  521. <span class="sd"> scheduler: torch.optim._LRScheduler, the learning rate scheduler to be called step() with.</span>
  522. <span class="sd"> metric_name: str, (default=None) the metric name for ReduceLROnPlateau learning rate scheduler.</span>
  523. <span class="sd"> When passing __call__ a metrics_dict, with a key=self.metric_name, the value of that metric will monitored</span>
  524. <span class="sd"> for ReduceLROnPlateau (i.e step(metrics_dict[self.metric_name]).</span>
  525. <span class="sd"> &quot;&quot;&quot;</span>
  526. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">scheduler</span><span class="p">,</span> <span class="n">phase</span><span class="p">,</span> <span class="n">metric_name</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  527. <span class="nb">super</span><span class="p">(</span><span class="n">LRSchedulerCallback</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">phase</span><span class="p">)</span>
  528. <span class="bp">self</span><span class="o">.</span><span class="n">scheduler</span> <span class="o">=</span> <span class="n">scheduler</span>
  529. <span class="bp">self</span><span class="o">.</span><span class="n">metric_name</span> <span class="o">=</span> <span class="n">metric_name</span>
  530. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">PhaseContext</span><span class="p">):</span>
  531. <span class="k">if</span> <span class="n">context</span><span class="o">.</span><span class="n">lr_warmup_epochs</span> <span class="o">&lt;=</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span><span class="p">:</span>
  532. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">metric_name</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">metric_name</span> <span class="ow">in</span> <span class="n">context</span><span class="o">.</span><span class="n">metrics_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
  533. <span class="bp">self</span><span class="o">.</span><span class="n">scheduler</span><span class="o">.</span><span class="n">step</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">metrics_dict</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">metric_name</span><span class="p">])</span>
  534. <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">metric_name</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
  535. <span class="bp">self</span><span class="o">.</span><span class="n">scheduler</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
  536. <span class="k">else</span><span class="p">:</span>
  537. <span class="k">raise</span> <span class="n">IllegalLRSchedulerMetric</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">metric_name</span><span class="p">,</span> <span class="n">context</span><span class="o">.</span><span class="n">metrics_dict</span><span class="p">)</span>
  538. <span class="k">def</span> <span class="fm">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  539. <span class="k">return</span> <span class="s2">&quot;LRSchedulerCallback: &quot;</span> <span class="o">+</span> <span class="nb">repr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">scheduler</span><span class="p">)</span></div>
  540. <div class="viewcode-block" id="MetricsUpdateCallback"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.MetricsUpdateCallback">[docs]</a><span class="k">class</span> <span class="nc">MetricsUpdateCallback</span><span class="p">(</span><span class="n">PhaseCallback</span><span class="p">):</span>
  541. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">phase</span><span class="p">:</span> <span class="n">Phase</span><span class="p">):</span>
  542. <span class="nb">super</span><span class="p">(</span><span class="n">MetricsUpdateCallback</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">phase</span><span class="p">)</span>
  543. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">PhaseContext</span><span class="p">):</span>
  544. <span class="n">context</span><span class="o">.</span><span class="n">metrics_compute_fn</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="o">**</span><span class="n">context</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">)</span>
  545. <span class="k">if</span> <span class="n">context</span><span class="o">.</span><span class="n">criterion</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  546. <span class="n">context</span><span class="o">.</span><span class="n">loss_avg_meter</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">loss_log_items</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">inputs</span><span class="p">))</span></div>
  547. <div class="viewcode-block" id="KDModelMetricsUpdateCallback"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.KDModelMetricsUpdateCallback">[docs]</a><span class="k">class</span> <span class="nc">KDModelMetricsUpdateCallback</span><span class="p">(</span><span class="n">MetricsUpdateCallback</span><span class="p">):</span>
  548. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">phase</span><span class="p">:</span> <span class="n">Phase</span><span class="p">):</span>
  549. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">phase</span><span class="o">=</span><span class="n">phase</span><span class="p">)</span>
  550. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">PhaseContext</span><span class="p">):</span>
  551. <span class="n">metrics_compute_fn_kwargs</span> <span class="o">=</span> <span class="p">{</span><span class="n">k</span><span class="p">:</span> <span class="n">v</span><span class="o">.</span><span class="n">student_output</span> <span class="k">if</span> <span class="n">k</span> <span class="o">==</span> <span class="s2">&quot;preds&quot;</span> <span class="k">else</span> <span class="n">v</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">context</span><span class="o">.</span><span class="vm">__dict__</span><span class="o">.</span><span class="n">items</span><span class="p">()}</span>
  552. <span class="n">context</span><span class="o">.</span><span class="n">metrics_compute_fn</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="o">**</span><span class="n">metrics_compute_fn_kwargs</span><span class="p">)</span>
  553. <span class="k">if</span> <span class="n">context</span><span class="o">.</span><span class="n">criterion</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  554. <span class="n">context</span><span class="o">.</span><span class="n">loss_avg_meter</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">loss_log_items</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">inputs</span><span class="p">))</span></div>
  555. <div class="viewcode-block" id="PhaseContextTestCallback"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.PhaseContextTestCallback">[docs]</a><span class="k">class</span> <span class="nc">PhaseContextTestCallback</span><span class="p">(</span><span class="n">PhaseCallback</span><span class="p">):</span>
  556. <span class="sd">&quot;&quot;&quot;</span>
  557. <span class="sd"> A callback that saves the phase context the for testing.</span>
  558. <span class="sd"> &quot;&quot;&quot;</span>
  559. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">phase</span><span class="p">:</span> <span class="n">Phase</span><span class="p">):</span>
  560. <span class="nb">super</span><span class="p">(</span><span class="n">PhaseContextTestCallback</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">phase</span><span class="p">)</span>
  561. <span class="bp">self</span><span class="o">.</span><span class="n">context</span> <span class="o">=</span> <span class="kc">None</span>
  562. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">PhaseContext</span><span class="p">):</span>
  563. <span class="bp">self</span><span class="o">.</span><span class="n">context</span> <span class="o">=</span> <span class="n">context</span></div>
  564. <div class="viewcode-block" id="DetectionVisualizationCallback"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.DetectionVisualizationCallback">[docs]</a><span class="k">class</span> <span class="nc">DetectionVisualizationCallback</span><span class="p">(</span><span class="n">PhaseCallback</span><span class="p">):</span>
  565. <span class="sd">&quot;&quot;&quot;</span>
  566. <span class="sd"> A callback that adds a visualization of a batch of detection predictions to context.sg_logger</span>
  567. <span class="sd"> Attributes:</span>
  568. <span class="sd"> freq: frequency (in epochs) to perform this callback.</span>
  569. <span class="sd"> batch_idx: batch index to perform visualization for.</span>
  570. <span class="sd"> classes: class list of the dataset.</span>
  571. <span class="sd"> last_img_idx_in_batch: Last image index to add to log. (default=-1, will take entire batch).</span>
  572. <span class="sd"> &quot;&quot;&quot;</span>
  573. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
  574. <span class="bp">self</span><span class="p">,</span>
  575. <span class="n">phase</span><span class="p">:</span> <span class="n">Phase</span><span class="p">,</span>
  576. <span class="n">freq</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
  577. <span class="n">post_prediction_callback</span><span class="p">:</span> <span class="n">DetectionPostPredictionCallback</span><span class="p">,</span>
  578. <span class="n">classes</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span>
  579. <span class="n">batch_idx</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
  580. <span class="n">last_img_idx_in_batch</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span>
  581. <span class="p">):</span>
  582. <span class="nb">super</span><span class="p">(</span><span class="n">DetectionVisualizationCallback</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">phase</span><span class="p">)</span>
  583. <span class="bp">self</span><span class="o">.</span><span class="n">freq</span> <span class="o">=</span> <span class="n">freq</span>
  584. <span class="bp">self</span><span class="o">.</span><span class="n">post_prediction_callback</span> <span class="o">=</span> <span class="n">post_prediction_callback</span>
  585. <span class="bp">self</span><span class="o">.</span><span class="n">batch_idx</span> <span class="o">=</span> <span class="n">batch_idx</span>
  586. <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="n">classes</span>
  587. <span class="bp">self</span><span class="o">.</span><span class="n">last_img_idx_in_batch</span> <span class="o">=</span> <span class="n">last_img_idx_in_batch</span>
  588. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">PhaseContext</span><span class="p">):</span>
  589. <span class="k">if</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">freq</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">context</span><span class="o">.</span><span class="n">batch_idx</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_idx</span><span class="p">:</span>
  590. <span class="c1"># SOME CALCULATIONS ARE IN-PLACE IN NMS, SO CLONE THE PREDICTIONS</span>
  591. <span class="n">preds</span> <span class="o">=</span> <span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">preds</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">clone</span><span class="p">(),</span> <span class="kc">None</span><span class="p">)</span>
  592. <span class="n">preds</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">post_prediction_callback</span><span class="p">(</span><span class="n">preds</span><span class="p">)</span>
  593. <span class="n">batch_imgs</span> <span class="o">=</span> <span class="n">DetectionVisualization</span><span class="o">.</span><span class="n">visualize_batch</span><span class="p">(</span>
  594. <span class="n">context</span><span class="o">.</span><span class="n">inputs</span><span class="p">,</span> <span class="n">preds</span><span class="p">,</span> <span class="n">context</span><span class="o">.</span><span class="n">target</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_idx</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">classes</span>
  595. <span class="p">)</span>
  596. <span class="n">batch_imgs</span> <span class="o">=</span> <span class="p">[</span><span class="n">cv2</span><span class="o">.</span><span class="n">cvtColor</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">cv2</span><span class="o">.</span><span class="n">COLOR_BGR2RGB</span><span class="p">)</span> <span class="k">for</span> <span class="n">image</span> <span class="ow">in</span> <span class="n">batch_imgs</span><span class="p">]</span>
  597. <span class="n">batch_imgs</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">batch_imgs</span><span class="p">)</span>
  598. <span class="n">tag</span> <span class="o">=</span> <span class="s2">&quot;batch_&quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">batch_idx</span><span class="p">)</span> <span class="o">+</span> <span class="s2">&quot;_images&quot;</span>
  599. <span class="n">context</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">add_images</span><span class="p">(</span>
  600. <span class="n">tag</span><span class="o">=</span><span class="n">tag</span><span class="p">,</span> <span class="n">images</span><span class="o">=</span><span class="n">batch_imgs</span><span class="p">[:</span> <span class="bp">self</span><span class="o">.</span><span class="n">last_img_idx_in_batch</span><span class="p">],</span> <span class="n">global_step</span><span class="o">=</span><span class="n">context</span><span class="o">.</span><span class="n">epoch</span><span class="p">,</span> <span class="n">data_format</span><span class="o">=</span><span class="s2">&quot;NHWC&quot;</span>
  601. <span class="p">)</span></div>
  602. <div class="viewcode-block" id="BinarySegmentationVisualizationCallback"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.BinarySegmentationVisualizationCallback">[docs]</a><span class="k">class</span> <span class="nc">BinarySegmentationVisualizationCallback</span><span class="p">(</span><span class="n">PhaseCallback</span><span class="p">):</span>
  603. <span class="sd">&quot;&quot;&quot;</span>
  604. <span class="sd"> A callback that adds a visualization of a batch of segmentation predictions to context.sg_logger</span>
  605. <span class="sd"> Attributes:</span>
  606. <span class="sd"> freq: frequency (in epochs) to perform this callback.</span>
  607. <span class="sd"> batch_idx: batch index to perform visualization for.</span>
  608. <span class="sd"> last_img_idx_in_batch: Last image index to add to log. (default=-1, will take entire batch).</span>
  609. <span class="sd"> &quot;&quot;&quot;</span>
  610. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">phase</span><span class="p">:</span> <span class="n">Phase</span><span class="p">,</span> <span class="n">freq</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="n">last_img_idx_in_batch</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span><span class="p">):</span>
  611. <span class="nb">super</span><span class="p">(</span><span class="n">BinarySegmentationVisualizationCallback</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">phase</span><span class="p">)</span>
  612. <span class="bp">self</span><span class="o">.</span><span class="n">freq</span> <span class="o">=</span> <span class="n">freq</span>
  613. <span class="bp">self</span><span class="o">.</span><span class="n">batch_idx</span> <span class="o">=</span> <span class="n">batch_idx</span>
  614. <span class="bp">self</span><span class="o">.</span><span class="n">last_img_idx_in_batch</span> <span class="o">=</span> <span class="n">last_img_idx_in_batch</span>
  615. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">PhaseContext</span><span class="p">):</span>
  616. <span class="k">if</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">freq</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">context</span><span class="o">.</span><span class="n">batch_idx</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_idx</span><span class="p">:</span>
  617. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">preds</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">):</span>
  618. <span class="n">preds</span> <span class="o">=</span> <span class="n">context</span><span class="o">.</span><span class="n">preds</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
  619. <span class="k">else</span><span class="p">:</span>
  620. <span class="n">preds</span> <span class="o">=</span> <span class="n">context</span><span class="o">.</span><span class="n">preds</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
  621. <span class="n">batch_imgs</span> <span class="o">=</span> <span class="n">BinarySegmentationVisualization</span><span class="o">.</span><span class="n">visualize_batch</span><span class="p">(</span>
  622. <span class="n">context</span><span class="o">.</span><span class="n">inputs</span><span class="p">,</span> <span class="n">preds</span><span class="p">,</span> <span class="n">context</span><span class="o">.</span><span class="n">target</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_idx</span>
  623. <span class="p">)</span>
  624. <span class="n">batch_imgs</span> <span class="o">=</span> <span class="p">[</span><span class="n">cv2</span><span class="o">.</span><span class="n">cvtColor</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">cv2</span><span class="o">.</span><span class="n">COLOR_BGR2RGB</span><span class="p">)</span> <span class="k">for</span> <span class="n">image</span> <span class="ow">in</span> <span class="n">batch_imgs</span><span class="p">]</span>
  625. <span class="n">batch_imgs</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">batch_imgs</span><span class="p">)</span>
  626. <span class="n">tag</span> <span class="o">=</span> <span class="s2">&quot;batch_&quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">batch_idx</span><span class="p">)</span> <span class="o">+</span> <span class="s2">&quot;_images&quot;</span>
  627. <span class="n">context</span><span class="o">.</span><span class="n">sg_logger</span><span class="o">.</span><span class="n">add_images</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="n">tag</span><span class="p">,</span> <span class="n">images</span><span class="o">=</span><span class="n">batch_imgs</span><span class="p">[:</span><span class="bp">self</span><span class="o">.</span><span class="n">last_img_idx_in_batch</span><span class="p">],</span>
  628. <span class="n">global_step</span><span class="o">=</span><span class="n">context</span><span class="o">.</span><span class="n">epoch</span><span class="p">,</span> <span class="n">data_format</span><span class="o">=</span><span class="s1">&#39;NHWC&#39;</span><span class="p">)</span></div>
  629. <div class="viewcode-block" id="TrainingStageSwitchCallbackBase"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.TrainingStageSwitchCallbackBase">[docs]</a><span class="k">class</span> <span class="nc">TrainingStageSwitchCallbackBase</span><span class="p">(</span><span class="n">PhaseCallback</span><span class="p">):</span>
  630. <span class="sd">&quot;&quot;&quot;</span>
  631. <span class="sd"> TrainingStageSwitchCallback</span>
  632. <span class="sd"> A phase callback that is called at a specific epoch (epoch start) to support multi-stage training.</span>
  633. <span class="sd"> It does so by manipulating the objects inside the context.</span>
  634. <span class="sd"> Attributes:</span>
  635. <span class="sd"> next_stage_start_epoch: int, the epoch idx to apply the stage change.</span>
  636. <span class="sd"> &quot;&quot;&quot;</span>
  637. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">next_stage_start_epoch</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
  638. <span class="nb">super</span><span class="p">(</span><span class="n">TrainingStageSwitchCallbackBase</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">phase</span><span class="o">=</span><span class="n">Phase</span><span class="o">.</span><span class="n">TRAIN_EPOCH_START</span><span class="p">)</span>
  639. <span class="bp">self</span><span class="o">.</span><span class="n">next_stage_start_epoch</span> <span class="o">=</span> <span class="n">next_stage_start_epoch</span>
  640. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">PhaseContext</span><span class="p">):</span>
  641. <span class="k">if</span> <span class="n">context</span><span class="o">.</span><span class="n">epoch</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">next_stage_start_epoch</span><span class="p">:</span>
  642. <span class="bp">self</span><span class="o">.</span><span class="n">apply_stage_change</span><span class="p">(</span><span class="n">context</span><span class="p">)</span>
  643. <div class="viewcode-block" id="TrainingStageSwitchCallbackBase.apply_stage_change"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.TrainingStageSwitchCallbackBase.apply_stage_change">[docs]</a> <span class="k">def</span> <span class="nf">apply_stage_change</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">PhaseContext</span><span class="p">):</span>
  644. <span class="sd">&quot;&quot;&quot;</span>
  645. <span class="sd"> This method is called when the callback is fired on the next_stage_start_epoch,</span>
  646. <span class="sd"> and holds the stage change logic that should be applied to the context&#39;s objects.</span>
  647. <span class="sd"> :param context: PhaseContext, context of current phase</span>
  648. <span class="sd"> &quot;&quot;&quot;</span>
  649. <span class="k">raise</span> <span class="ne">NotImplementedError</span></div></div>
  650. <div class="viewcode-block" id="YoloXTrainingStageSwitchCallback"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.YoloXTrainingStageSwitchCallback">[docs]</a><span class="k">class</span> <span class="nc">YoloXTrainingStageSwitchCallback</span><span class="p">(</span><span class="n">TrainingStageSwitchCallbackBase</span><span class="p">):</span>
  651. <span class="sd">&quot;&quot;&quot;</span>
  652. <span class="sd"> YoloXTrainingStageSwitchCallback</span>
  653. <span class="sd"> Training stage switch for YoloX training.</span>
  654. <span class="sd"> Disables mosaic, and manipulates YoloX loss to use L1.</span>
  655. <span class="sd"> &quot;&quot;&quot;</span>
  656. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">next_stage_start_epoch</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">285</span><span class="p">):</span>
  657. <span class="nb">super</span><span class="p">(</span><span class="n">YoloXTrainingStageSwitchCallback</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">next_stage_start_epoch</span><span class="o">=</span><span class="n">next_stage_start_epoch</span><span class="p">)</span>
  658. <div class="viewcode-block" id="YoloXTrainingStageSwitchCallback.apply_stage_change"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.YoloXTrainingStageSwitchCallback.apply_stage_change">[docs]</a> <span class="k">def</span> <span class="nf">apply_stage_change</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">PhaseContext</span><span class="p">):</span>
  659. <span class="k">for</span> <span class="n">transform</span> <span class="ow">in</span> <span class="n">context</span><span class="o">.</span><span class="n">train_loader</span><span class="o">.</span><span class="n">dataset</span><span class="o">.</span><span class="n">transforms</span><span class="p">:</span>
  660. <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">transform</span><span class="p">,</span> <span class="s2">&quot;close&quot;</span><span class="p">):</span>
  661. <span class="n">transform</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>
  662. <span class="nb">iter</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">train_loader</span><span class="p">)</span>
  663. <span class="n">context</span><span class="o">.</span><span class="n">criterion</span><span class="o">.</span><span class="n">use_l1</span> <span class="o">=</span> <span class="kc">True</span></div></div>
  664. <div class="viewcode-block" id="CallbackHandler"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.CallbackHandler">[docs]</a><span class="k">class</span> <span class="nc">CallbackHandler</span><span class="p">:</span>
  665. <span class="sd">&quot;&quot;&quot;</span>
  666. <span class="sd"> Runs all callbacks who&#39;s phase attribute equals to the given phase.</span>
  667. <span class="sd"> Attributes:</span>
  668. <span class="sd"> callbacks: List[PhaseCallback]. Callbacks to be run.</span>
  669. <span class="sd"> &quot;&quot;&quot;</span>
  670. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">callbacks</span><span class="p">):</span>
  671. <span class="bp">self</span><span class="o">.</span><span class="n">callbacks</span> <span class="o">=</span> <span class="n">callbacks</span>
  672. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">phase</span><span class="p">:</span> <span class="n">Phase</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">PhaseContext</span><span class="p">):</span>
  673. <span class="k">for</span> <span class="n">callback</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">callbacks</span><span class="p">:</span>
  674. <span class="k">if</span> <span class="n">callback</span><span class="o">.</span><span class="n">phase</span> <span class="o">==</span> <span class="n">phase</span><span class="p">:</span>
  675. <span class="n">callback</span><span class="p">(</span><span class="n">context</span><span class="p">)</span></div>
  676. <span class="c1"># DICT FOR LEGACY LR HARD-CODED REGIMES, WILL BE DELETED IN THE FUTURE</span>
  677. <span class="n">LR_SCHEDULERS_CLS_DICT</span> <span class="o">=</span> <span class="p">{</span>
  678. <span class="s2">&quot;step&quot;</span><span class="p">:</span> <span class="n">StepLRCallback</span><span class="p">,</span>
  679. <span class="s2">&quot;poly&quot;</span><span class="p">:</span> <span class="n">PolyLRCallback</span><span class="p">,</span>
  680. <span class="s2">&quot;cosine&quot;</span><span class="p">:</span> <span class="n">CosineLRCallback</span><span class="p">,</span>
  681. <span class="s2">&quot;exp&quot;</span><span class="p">:</span> <span class="n">ExponentialLRCallback</span><span class="p">,</span>
  682. <span class="s2">&quot;function&quot;</span><span class="p">:</span> <span class="n">FunctionLRCallback</span><span class="p">,</span>
  683. <span class="p">}</span>
  684. <span class="n">LR_WARMUP_CLS_DICT</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;linear_step&quot;</span><span class="p">:</span> <span class="n">WarmupLRCallback</span><span class="p">}</span>
  685. <div class="viewcode-block" id="TestLRCallback"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.callbacks.TestLRCallback">[docs]</a><span class="k">class</span> <span class="nc">TestLRCallback</span><span class="p">(</span><span class="n">PhaseCallback</span><span class="p">):</span>
  686. <span class="sd">&quot;&quot;&quot;</span>
  687. <span class="sd"> Phase callback that collects the learning rates in lr_placeholder at the end of each epoch (used for testing). In</span>
  688. <span class="sd"> the case of multiple parameter groups (i.e multiple learning rates) the learning rate is collected from the first</span>
  689. <span class="sd"> one. The phase is VALIDATION_EPOCH_END to ensure all lr updates have been performed before calling this callback.</span>
  690. <span class="sd"> &quot;&quot;&quot;</span>
  691. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">lr_placeholder</span><span class="p">):</span>
  692. <span class="nb">super</span><span class="p">(</span><span class="n">TestLRCallback</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">Phase</span><span class="o">.</span><span class="n">VALIDATION_EPOCH_END</span><span class="p">)</span>
  693. <span class="bp">self</span><span class="o">.</span><span class="n">lr_placeholder</span> <span class="o">=</span> <span class="n">lr_placeholder</span>
  694. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">PhaseContext</span><span class="p">):</span>
  695. <span class="bp">self</span><span class="o">.</span><span class="n">lr_placeholder</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">param_groups</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="s2">&quot;lr&quot;</span><span class="p">])</span></div>
  696. </pre></div>
  697. </div>
  698. </div>
  699. <footer>
  700. <hr/>
  701. <div role="contentinfo">
  702. <p>&#169; Copyright 2021, SuperGradients team.</p>
  703. </div>
  704. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  705. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  706. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  707. </footer>
  708. </div>
  709. </div>
  710. </section>
  711. </div>
  712. <script>
  713. jQuery(function () {
  714. SphinxRtdTheme.Navigation.enable(true);
  715. });
  716. </script>
  717. </body>
  718. </html>
Discard
Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
@@ -3,10 +3,11 @@
 <head>
 <head>
   <meta charset="utf-8" />
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-  <title>super_gradients.training.utils.checkpoint_utils &mdash; SuperGradients 1.0 documentation</title>
+  <title>super_gradients.training.utils.checkpoint_utils &mdash; SuperGradients 3.0.3 documentation</title>
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
+      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
   <!--[if lt IE 9]>
   <!--[if lt IE 9]>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
   <![endif]-->
   <![endif]-->
@@ -14,7 +15,9 @@
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/underscore.js"></script>
         <script src="../../../../_static/underscore.js"></script>
+        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
         <script src="../../../../_static/doctools.js"></script>
         <script src="../../../../_static/doctools.js"></script>
+        <script src="../../../../_static/sphinx_highlight.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="search" title="Search" href="../../../../search.html" /> 
     <link rel="search" title="Search" href="../../../../search.html" /> 
@@ -35,30 +38,29 @@
   </form>
   </form>
 </div>
 </div>
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
-              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
+              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
 <ul>
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
 </ul>
 </ul>
-<p class="caption"><span class="caption-text">Technical Documentation</span></p>
+<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
 <ul>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
-</ul>
-<p class="caption"><span class="caption-text">User Guide</span></p>
-<ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
 </ul>
 </ul>
 
 
         </div>
         </div>
@@ -89,17 +91,40 @@
 <span></span><span class="kn">import</span> <span class="nn">os</span>
 <span></span><span class="kn">import</span> <span class="nn">os</span>
 <span class="kn">import</span> <span class="nn">tempfile</span>
 <span class="kn">import</span> <span class="nn">tempfile</span>
 <span class="kn">import</span> <span class="nn">pkg_resources</span>
 <span class="kn">import</span> <span class="nn">pkg_resources</span>
+
 <span class="kn">import</span> <span class="nn">torch</span>
 <span class="kn">import</span> <span class="nn">torch</span>
+
+<span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
 <span class="kn">from</span> <span class="nn">super_gradients.common</span> <span class="kn">import</span> <span class="n">explicit_params_validation</span><span class="p">,</span> <span class="n">ADNNModelRepositoryDataInterfaces</span>
 <span class="kn">from</span> <span class="nn">super_gradients.common</span> <span class="kn">import</span> <span class="n">explicit_params_validation</span><span class="p">,</span> <span class="n">ADNNModelRepositoryDataInterfaces</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.pretrained_models</span> <span class="kn">import</span> <span class="n">MODEL_URLS</span>
 <span class="kn">from</span> <span class="nn">super_gradients.training.pretrained_models</span> <span class="kn">import</span> <span class="n">MODEL_URLS</span>
+<span class="kn">from</span> <span class="nn">super_gradients.common.environment</span> <span class="kn">import</span> <span class="n">environment_config</span>
+
 <span class="k">try</span><span class="p">:</span>
 <span class="k">try</span><span class="p">:</span>
     <span class="kn">from</span> <span class="nn">torch.hub</span> <span class="kn">import</span> <span class="n">download_url_to_file</span><span class="p">,</span> <span class="n">load_state_dict_from_url</span>
     <span class="kn">from</span> <span class="nn">torch.hub</span> <span class="kn">import</span> <span class="n">download_url_to_file</span><span class="p">,</span> <span class="n">load_state_dict_from_url</span>
 <span class="k">except</span> <span class="p">(</span><span class="ne">ModuleNotFoundError</span><span class="p">,</span> <span class="ne">ImportError</span><span class="p">,</span> <span class="ne">NameError</span><span class="p">):</span>
 <span class="k">except</span> <span class="p">(</span><span class="ne">ModuleNotFoundError</span><span class="p">,</span> <span class="ne">ImportError</span><span class="p">,</span> <span class="ne">NameError</span><span class="p">):</span>
     <span class="kn">from</span> <span class="nn">torch.hub</span> <span class="kn">import</span> <span class="n">_download_url_to_file</span> <span class="k">as</span> <span class="n">download_url_to_file</span>
     <span class="kn">from</span> <span class="nn">torch.hub</span> <span class="kn">import</span> <span class="n">_download_url_to_file</span> <span class="k">as</span> <span class="n">download_url_to_file</span>
 
 
 
 
-<div class="viewcode-block" id="get_ckpt_local_path"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.get_ckpt_local_path">[docs]</a><span class="k">def</span> <span class="nf">get_ckpt_local_path</span><span class="p">(</span><span class="n">source_ckpt_folder_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">experiment_name</span><span class="p">:</span> <span class="nb">str</spa
-                        <span class="n">overwrite_local_checkpoint</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span> <span class="n">load_weights_only</span><span class="p">:</span> <span class="nb">bool</span><span class="p">):</span>
+<span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
+
+
+<span class="k">def</span> <span class="nf">get_checkpoints_dir_path</span><span class="p">(</span><span class="n">experiment_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">ckpt_root_dir</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
+    <span class="sd">&quot;&quot;&quot;Creating the checkpoint directory of a given experiment.</span>
+<span class="sd">    :param experiment_name:     Name of the experiment.</span>
+<span class="sd">    :param ckpt_root_dir:       Local root directory path where all experiment logging directories will</span>
+<span class="sd">                                reside. When none is give, it is assumed that pkg_resources.resource_filename(&#39;checkpoints&#39;, &quot;&quot;)</span>
+<span class="sd">                                exists and will be used.</span>
+<span class="sd">    :return:                    checkpoints_dir_path</span>
+<span class="sd">    &quot;&quot;&quot;</span>
+    <span class="k">if</span> <span class="n">ckpt_root_dir</span><span class="p">:</span>
+        <span class="k">return</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">ckpt_root_dir</span><span class="p">,</span> <span class="n">experiment_name</span><span class="p">)</span>
+    <span class="k">elif</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">environment_config</span><span class="o">.</span><span class="n">PKG_CHECKPOINTS_DIR</span><span class="p">):</span>
+        <span class="k">return</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">environment_config</span><span class="o">.</span><span class="n">PKG_CHECKPOINTS_DIR</span><span class="p">,</span> <span class="n">experiment_name</span><span class="p">)</span>
+    <span class="k">else</span><span class="p">:</span>
+        <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Illegal checkpoints directory: pass ckpt_root_dir that exists, or add &#39;checkpoints&#39; to resources.&quot;</span><span class="p">)</span>
+
+
+<span class="k">def</span> <span class="nf">get_ckpt_local_path</span><span class="p">(</span><span class="n">source_ckpt_folder_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">experiment_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">ckpt_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">external_checkpoint_path</span><span class
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    Gets the local path to the checkpoint file, which will be:</span>
 <span class="sd">    Gets the local path to the checkpoint file, which will be:</span>
 <span class="sd">        - By default: YOUR_REPO_ROOT/super_gradients/checkpoints/experiment_name.</span>
 <span class="sd">        - By default: YOUR_REPO_ROOT/super_gradients/checkpoints/experiment_name.</span>
@@ -110,37 +135,20 @@
 <span class="sd">        - external_checkpoint_path when external_checkpoint_path != None</span>
 <span class="sd">        - external_checkpoint_path when external_checkpoint_path != None</span>
 
 
 <span class="sd">    @param source_ckpt_folder_name: The folder where the checkpoint is saved. When set to None- uses the experiment_name.</span>
 <span class="sd">    @param source_ckpt_folder_name: The folder where the checkpoint is saved. When set to None- uses the experiment_name.</span>
-<span class="sd">    @param experiment_name: experiment name attr in sg_model</span>
+<span class="sd">    @param experiment_name: experiment name attr in trainer</span>
 <span class="sd">    @param ckpt_name: checkpoint filename</span>
 <span class="sd">    @param ckpt_name: checkpoint filename</span>
-<span class="sd">    @param model_checkpoints_location: S3, local ot URL</span>
 <span class="sd">    @param external_checkpoint_path: full path to checkpoint file (that might be located outside of super_gradients/checkpoints directory)</span>
 <span class="sd">    @param external_checkpoint_path: full path to checkpoint file (that might be located outside of super_gradients/checkpoints directory)</span>
-<span class="sd">    @param overwrite_local_checkpoint: whether to overwrite the checkpoint file with the same name when downloading from S3.</span>
-<span class="sd">    @param load_weights_only: whether to load the network&#39;s state dict only.</span>
 <span class="sd">    @return:</span>
 <span class="sd">    @return:</span>
 <span class="sd">    &quot;&quot;&quot;</span>
 <span class="sd">    &quot;&quot;&quot;</span>
-    <span class="n">source_ckpt_folder_name</span> <span class="o">=</span> <span class="n">source_ckpt_folder_name</span> <span class="ow">or</span> <span class="n">experiment_name</span>
-    <span class="k">if</span> <span class="n">model_checkpoints_location</span> <span class="o">==</span> <span class="s1">&#39;local&#39;</span><span class="p">:</span>
-        <span class="n">ckpt_local_path</span> <span class="o">=</span> <span class="n">external_checkpoint_path</span> <span class="ow">or</span> <span class="n">pkg_resources</span><span class="o">.</span><span class="n">resource_filename</span><span class="p">(</span><span class="s1">&#39;checkpoints&#39;</span><span class="p">,</span> <span class="n">source_ckpt_folder_name</span> <span class="o">+</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span clas
-
-    <span class="c1"># COPY THE DATA FROM &#39;S3&#39;/&#39;URL&#39; INTO A LOCAL DIRECTORY</span>
-    <span class="k">elif</span> <span class="n">model_checkpoints_location</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s1">&#39;s3&#39;</span><span class="p">)</span> <span class="ow">or</span> <span class="n">model_checkpoints_location</span> <span class="o">==</span> <span class="s1">&#39;url&#39;</span><span class="p">:</span>
-        <span class="c1"># COPY REMOTE DATA TO A LOCAL DIRECTORY AND GET THAT DIRECTORYs NAME</span>
-        <span class="n">ckpt_local_path</span> <span class="o">=</span> <span class="n">copy_ckpt_to_local_folder</span><span class="p">(</span><span class="n">local_ckpt_destination_dir</span><span class="o">=</span><span class="n">experiment_name</span><span class="p">,</span>
-                                                    <span class="n">ckpt_filename</span><span class="o">=</span><span class="n">ckpt_name</span><span class="p">,</span>
-                                                    <span class="n">remote_ckpt_source_dir</span><span class="o">=</span><span class="n">source_ckpt_folder_name</span><span class="p">,</span>
-                                                    <span class="n">path_src</span><span class="o">=</span><span class="n">model_checkpoints_location</span><span class="p">,</span>
-                                                    <span class="n">overwrite_local_ckpt</span><span class="o">=</span><span class="n">overwrite_local_checkpoint</span><span class="p">,</span>
-                                                    <span class="n">load_weights_only</span><span class="o">=</span><span class="n">load_weights_only</span><span class="p">)</span>
-
+    <span class="k">if</span> <span class="n">external_checkpoint_path</span><span class="p">:</span>
+        <span class="k">return</span> <span class="n">external_checkpoint_path</span>
     <span class="k">else</span><span class="p">:</span>
     <span class="k">else</span><span class="p">:</span>
-        <span class="c1"># ERROR IN USER CODE FLOW - THIS WILL EVENTUALLY RAISE AN EXCEPTION</span>
-        <span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span>
-            <span class="s1">&#39;model_checkpoints_data_source: &#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">model_checkpoints_location</span><span class="p">)</span> <span class="o">+</span> <span class="s1">&#39;not supported&#39;</span><span class="p">)</span>
+        <span class="n">checkpoints_folder_name</span> <span class="o">=</span> <span class="n">source_ckpt_folder_name</span> <span class="ow">or</span> <span class="n">experiment_name</span>
+        <span class="n">checkpoints_dir_path</span> <span class="o">=</span> <span class="n">get_checkpoints_dir_path</span><span class="p">(</span><span class="n">checkpoints_folder_name</span><span class="p">)</span>
+        <span class="k">return</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">checkpoints_dir_path</span><span class="p">,</span> <span class="n">ckpt_name</span><span class="p">)</span>
 
 
-    <span class="k">return</span> <span class="n">ckpt_local_path</span></div>
 
 
-
-<div class="viewcode-block" id="adaptive_load_state_dict"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.adaptive_load_state_dict">[docs]</a><span class="k">def</span> <span class="nf">adaptive_load_state_dict</span><span class="p">(</span><span class="n">net</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p"
+<span class="k">def</span> <span class="nf">adaptive_load_state_dict</span><span class="p">(</span><span class="n">net</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">state_dict</span><span class="p">:</span> <span class="nb">dict</span><span class="p">,</span> <span class="n">strict</span><span class="p">:</span> <span class="nb">str</span><s
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    Adaptively loads state_dict to net, by adapting the state_dict to net&#39;s layer names first.</span>
 <span class="sd">    Adaptively loads state_dict to net, by adapting the state_dict to net&#39;s layer names first.</span>
 
 
@@ -150,19 +158,24 @@
 <span class="sd">    @return:</span>
 <span class="sd">    @return:</span>
 <span class="sd">    &quot;&quot;&quot;</span>
 <span class="sd">    &quot;&quot;&quot;</span>
     <span class="k">try</span><span class="p">:</span>
     <span class="k">try</span><span class="p">:</span>
-        <span class="n">net</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">state_dict</span><span class="p">[</span><span class="s1">&#39;net&#39;</span><span class="p">],</span> <span class="n">strict</span><span class="o">=</span><span class="n">strict</span><span class="p">)</span>
+        <span class="n">net</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">state_dict</span><span class="p">[</span><span class="s2">&quot;net&quot;</span><span class="p">]</span> <span class="k">if</span> <span class="s2">&quot;net&quot;</span> <span class="ow">in</span> <span class="n">state_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span> <span class="k">else</span> <span class="n">state_dict<
     <span class="k">except</span> <span class="p">(</span><span class="ne">RuntimeError</span><span class="p">,</span> <span class="ne">ValueError</span><span class="p">,</span> <span class="ne">KeyError</span><span class="p">)</span> <span class="k">as</span> <span class="n">ex</span><span class="p">:</span>
     <span class="k">except</span> <span class="p">(</span><span class="ne">RuntimeError</span><span class="p">,</span> <span class="ne">ValueError</span><span class="p">,</span> <span class="ne">KeyError</span><span class="p">)</span> <span class="k">as</span> <span class="n">ex</span><span class="p">:</span>
-        <span class="k">if</span> <span class="n">strict</span> <span class="o">==</span> <span class="s1">&#39;no_key_matching&#39;</span><span class="p">:</span>
+        <span class="k">if</span> <span class="n">strict</span> <span class="o">==</span> <span class="s2">&quot;no_key_matching&quot;</span><span class="p">:</span>
             <span class="n">adapted_state_dict</span> <span class="o">=</span> <span class="n">adapt_state_dict_to_fit_model_layer_names</span><span class="p">(</span><span class="n">net</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span> <span class="n">state_dict</span><span class="p">)</span>
             <span class="n">adapted_state_dict</span> <span class="o">=</span> <span class="n">adapt_state_dict_to_fit_model_layer_names</span><span class="p">(</span><span class="n">net</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span> <span class="n">state_dict</span><span class="p">)</span>
-            <span class="n">net</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">adapted_state_dict</span><span class="p">[</span><span class="s1">&#39;net&#39;</span><span class="p">],</span> <span class="n">strict</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
+            <span class="n">net</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">adapted_state_dict</span><span class="p">[</span><span class="s2">&quot;net&quot;</span><span class="p">],</span> <span class="n">strict</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
         <span class="k">else</span><span class="p">:</span>
         <span class="k">else</span><span class="p">:</span>
-            <span class="n">raise_informative_runtime_error</span><span class="p">(</span><span class="n">net</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span> <span class="n">state_dict</span><span class="p">,</span> <span class="n">ex</span><span class="p">)</span></div>
-
-
-<span class="nd">@explicit_params_validation</span><span class="p">(</span><span class="n">validation_type</span><span class="o">=</span><span class="s1">&#39;None&#39;</span><span class="p">)</span>
-<span class="k">def</span> <span class="nf">copy_ckpt_to_local_folder</span><span class="p">(</span><span class="n">local_ckpt_destination_dir</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">ckpt_filename</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">remote_ckpt_source_dir</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span cla
-                              <span class="n">path_src</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39;local&#39;</span><span class="p">,</span> <span class="n">overwrite_local_ckpt</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
-                              <span class="n">load_weights_only</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
+            <span class="n">raise_informative_runtime_error</span><span class="p">(</span><span class="n">net</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span> <span class="n">state_dict</span><span class="p">,</span> <span class="n">ex</span><span class="p">)</span>
+
+
+<span class="nd">@explicit_params_validation</span><span class="p">(</span><span class="n">validation_type</span><span class="o">=</span><span class="s2">&quot;None&quot;</span><span class="p">)</span>
+<span class="k">def</span> <span class="nf">copy_ckpt_to_local_folder</span><span class="p">(</span>
+    <span class="n">local_ckpt_destination_dir</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
+    <span class="n">ckpt_filename</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
+    <span class="n">remote_ckpt_source_dir</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
+    <span class="n">path_src</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;local&quot;</span><span class="p">,</span>
+    <span class="n">overwrite_local_ckpt</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
+    <span class="n">load_weights_only</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
+<span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    Copy the checkpoint from any supported source to a local destination path</span>
 <span class="sd">    Copy the checkpoint from any supported source to a local destination path</span>
 <span class="sd">        :param local_ckpt_destination_dir:  destination where the checkpoint will be saved to</span>
 <span class="sd">        :param local_ckpt_destination_dir:  destination where the checkpoint will be saved to</span>
@@ -181,27 +194,31 @@
     <span class="k">if</span> <span class="ow">not</span> <span class="n">overwrite_local_ckpt</span><span class="p">:</span>
     <span class="k">if</span> <span class="ow">not</span> <span class="n">overwrite_local_ckpt</span><span class="p">:</span>
         <span class="c1"># CREATE A TEMP FOLDER TO SAVE THE CHECKPOINT TO</span>
         <span class="c1"># CREATE A TEMP FOLDER TO SAVE THE CHECKPOINT TO</span>
         <span class="n">download_ckpt_destination_dir</span> <span class="o">=</span> <span class="n">tempfile</span><span class="o">.</span><span class="n">gettempdir</span><span class="p">()</span>
         <span class="n">download_ckpt_destination_dir</span> <span class="o">=</span> <span class="n">tempfile</span><span class="o">.</span><span class="n">gettempdir</span><span class="p">()</span>
-        <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;PLEASE NOTICE - YOU ARE IMPORTING A REMOTE CHECKPOINT WITH overwrite_local_checkpoint = False &#39;</span>
-              <span class="s1">&#39;-&gt; IT WILL BE REDIRECTED TO A TEMP FOLDER AND DELETED ON MACHINE RESTART&#39;</span><span class="p">)</span>
+        <span class="nb">print</span><span class="p">(</span>
+            <span class="s2">&quot;PLEASE NOTICE - YOU ARE IMPORTING A REMOTE CHECKPOINT WITH overwrite_local_checkpoint = False &quot;</span>
+            <span class="s2">&quot;-&gt; IT WILL BE REDIRECTED TO A TEMP FOLDER AND DELETED ON MACHINE RESTART&quot;</span>
+        <span class="p">)</span>
     <span class="k">else</span><span class="p">:</span>
     <span class="k">else</span><span class="p">:</span>
         <span class="c1"># SAVE THE CHECKPOINT TO MODEL&#39;s FOLDER</span>
         <span class="c1"># SAVE THE CHECKPOINT TO MODEL&#39;s FOLDER</span>
-        <span class="n">download_ckpt_destination_dir</span> <span class="o">=</span> <span class="n">pkg_resources</span><span class="o">.</span><span class="n">resource_filename</span><span class="p">(</span><span class="s1">&#39;checkpoints&#39;</span><span class="p">,</span> <span class="n">local_ckpt_destination_dir</span><span class="p">)</span>
+        <span class="n">download_ckpt_destination_dir</span> <span class="o">=</span> <span class="n">pkg_resources</span><span class="o">.</span><span class="n">resource_filename</span><span class="p">(</span><span class="s2">&quot;checkpoints&quot;</span><span class="p">,</span> <span class="n">local_ckpt_destination_dir</span><span class="p">)</span>
 
 
-    <span class="k">if</span> <span class="n">path_src</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s1">&#39;s3&#39;</span><span class="p">):</span>
+    <span class="k">if</span> <span class="n">path_src</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s2">&quot;s3&quot;</span><span class="p">):</span>
         <span class="n">model_checkpoints_data_interface</span> <span class="o">=</span> <span class="n">ADNNModelRepositoryDataInterfaces</span><span class="p">(</span><span class="n">data_connection_location</span><span class="o">=</span><span class="n">path_src</span><span class="p">)</span>
         <span class="n">model_checkpoints_data_interface</span> <span class="o">=</span> <span class="n">ADNNModelRepositoryDataInterfaces</span><span class="p">(</span><span class="n">data_connection_location</span><span class="o">=</span><span class="n">path_src</span><span class="p">)</span>
         <span class="c1"># DOWNLOAD THE FILE FROM S3 TO THE DESTINATION FOLDER</span>
         <span class="c1"># DOWNLOAD THE FILE FROM S3 TO THE DESTINATION FOLDER</span>
         <span class="n">ckpt_file_full_local_path</span> <span class="o">=</span> <span class="n">model_checkpoints_data_interface</span><span class="o">.</span><span class="n">load_remote_checkpoints_file</span><span class="p">(</span>
         <span class="n">ckpt_file_full_local_path</span> <span class="o">=</span> <span class="n">model_checkpoints_data_interface</span><span class="o">.</span><span class="n">load_remote_checkpoints_file</span><span class="p">(</span>
             <span class="n">ckpt_source_remote_dir</span><span class="o">=</span><span class="n">remote_ckpt_source_dir</span><span class="p">,</span>
             <span class="n">ckpt_source_remote_dir</span><span class="o">=</span><span class="n">remote_ckpt_source_dir</span><span class="p">,</span>
             <span class="n">ckpt_destination_local_dir</span><span class="o">=</span><span class="n">download_ckpt_destination_dir</span><span class="p">,</span>
             <span class="n">ckpt_destination_local_dir</span><span class="o">=</span><span class="n">download_ckpt_destination_dir</span><span class="p">,</span>
             <span class="n">ckpt_file_name</span><span class="o">=</span><span class="n">ckpt_filename</span><span class="p">,</span>
             <span class="n">ckpt_file_name</span><span class="o">=</span><span class="n">ckpt_filename</span><span class="p">,</span>
-            <span class="n">overwrite_local_checkpoints_file</span><span class="o">=</span><span class="n">overwrite_local_ckpt</span><span class="p">)</span>
+            <span class="n">overwrite_local_checkpoints_file</span><span class="o">=</span><span class="n">overwrite_local_ckpt</span><span class="p">,</span>
+        <span class="p">)</span>
 
 
         <span class="k">if</span> <span class="ow">not</span> <span class="n">load_weights_only</span><span class="p">:</span>
         <span class="k">if</span> <span class="ow">not</span> <span class="n">load_weights_only</span><span class="p">:</span>
             <span class="c1"># COPY LOG FILES FROM THE REMOTE DIRECTORY TO THE LOCAL ONE ONLY IF LOADING THE CURRENT MODELs CKPT</span>
             <span class="c1"># COPY LOG FILES FROM THE REMOTE DIRECTORY TO THE LOCAL ONE ONLY IF LOADING THE CURRENT MODELs CKPT</span>
-            <span class="n">model_checkpoints_data_interface</span><span class="o">.</span><span class="n">load_all_remote_log_files</span><span class="p">(</span><span class="n">model_name</span><span class="o">=</span><span class="n">remote_ckpt_source_dir</span><span class="p">,</span>
-                                                                       <span class="n">model_checkpoint_local_dir</span><span class="o">=</span><span class="n">download_ckpt_destination_dir</span><span class="p">)</span>
+            <span class="n">model_checkpoints_data_interface</span><span class="o">.</span><span class="n">load_all_remote_log_files</span><span class="p">(</span>
+                <span class="n">model_name</span><span class="o">=</span><span class="n">remote_ckpt_source_dir</span><span class="p">,</span> <span class="n">model_checkpoint_local_dir</span><span class="o">=</span><span class="n">download_ckpt_destination_dir</span>
+            <span class="p">)</span>
 
 
-    <span class="k">if</span> <span class="n">path_src</span> <span class="o">==</span> <span class="s1">&#39;url&#39;</span><span class="p">:</span>
+    <span class="k">if</span> <span class="n">path_src</span> <span class="o">==</span> <span class="s2">&quot;url&quot;</span><span class="p">:</span>
         <span class="n">ckpt_file_full_local_path</span> <span class="o">=</span> <span class="n">download_ckpt_destination_dir</span> <span class="o">+</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">sep</span> <span class="o">+</span> <span class="n">ckpt_filename</span>
         <span class="n">ckpt_file_full_local_path</span> <span class="o">=</span> <span class="n">download_ckpt_destination_dir</span> <span class="o">+</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">sep</span> <span class="o">+</span> <span class="n">ckpt_filename</span>
         <span class="c1"># DOWNLOAD THE FILE FROM URL TO THE DESTINATION FOLDER</span>
         <span class="c1"># DOWNLOAD THE FILE FROM URL TO THE DESTINATION FOLDER</span>
         <span class="n">download_url_to_file</span><span class="p">(</span><span class="n">remote_ckpt_source_dir</span><span class="p">,</span> <span class="n">ckpt_file_full_local_path</span><span class="p">,</span> <span class="n">progress</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
         <span class="n">download_url_to_file</span><span class="p">(</span><span class="n">remote_ckpt_source_dir</span><span class="p">,</span> <span class="n">ckpt_file_full_local_path</span><span class="p">,</span> <span class="n">progress</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
@@ -209,20 +226,19 @@
     <span class="k">return</span> <span class="n">ckpt_file_full_local_path</span>
     <span class="k">return</span> <span class="n">ckpt_file_full_local_path</span>
 
 
 
 
-<div class="viewcode-block" id="read_ckpt_state_dict"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.read_ckpt_state_dict">[docs]</a><span class="k">def</span> <span class="nf">read_ckpt_state_dict</span><span class="p">(</span><span class="n">ckpt_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s2">&quot;cpu&quot;</span><span c
+<span class="k">def</span> <span class="nf">read_ckpt_state_dict</span><span class="p">(</span><span class="n">ckpt_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s2">&quot;cpu&quot;</span><span class="p">):</span>
     <span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">ckpt_path</span><span class="p">):</span>
     <span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">ckpt_path</span><span class="p">):</span>
-        <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">&#39;Incorrect Checkpoint path&#39;</span><span class="p">)</span>
+        <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Incorrect Checkpoint path&quot;</span><span class="p">)</span>
 
 
     <span class="k">if</span> <span class="n">device</span> <span class="o">==</span> <span class="s2">&quot;cuda&quot;</span><span class="p">:</span>
     <span class="k">if</span> <span class="n">device</span> <span class="o">==</span> <span class="s2">&quot;cuda&quot;</span><span class="p">:</span>
         <span class="n">state_dict</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">ckpt_path</span><span class="p">)</span>
         <span class="n">state_dict</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">ckpt_path</span><span class="p">)</span>
 
 
     <span class="k">else</span><span class="p">:</span>
     <span class="k">else</span><span class="p">:</span>
         <span class="n">state_dict</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">ckpt_path</span><span class="p">,</span> <span class="n">map_location</span><span class="o">=</span><span class="k">lambda</span> <span class="n">storage</span><span class="p">,</span> <span class="n">loc</span><span class="p">:</span> <span class="n">storage</span><span class="p">)</span>
         <span class="n">state_dict</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">ckpt_path</span><span class="p">,</span> <span class="n">map_location</span><span class="o">=</span><span class="k">lambda</span> <span class="n">storage</span><span class="p">,</span> <span class="n">loc</span><span class="p">:</span> <span class="n">storage</span><span class="p">)</span>
-    <span class="k">return</span> <span class="n">state_dict</span></div>
+    <span class="k">return</span> <span class="n">state_dict</span>
 
 
 
 
-<div class="viewcode-block" id="adapt_state_dict_to_fit_model_layer_names"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.adapt_state_dict_to_fit_model_layer_names">[docs]</a><span class="k">def</span> <span class="nf">adapt_state_dict_to_fit_model_layer_names</span><span class="p">(</span><span class="n">model_state_dict</span><span class="p">:</span> <span class="nb">dict</span><span class="p">,</span> <span class="n">source_ckpt<
-                                              <span class="n">exclude</span><span class="p">:</span> <span class="nb">list</span> <span class="o">=</span> <span class="p">[],</span> <span class="n">solver</span><span class="p">:</span> <span class="n">callable</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
+<div class="viewcode-block" id="adapt_state_dict_to_fit_model_layer_names"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.utils.adapt_state_dict_to_fit_model_layer_names">[docs]</a><span class="k">def</span> <span class="nf">adapt_state_dict_to_fit_model_layer_names</span><span class="p">(</span><span class="n">model_state_dict</span><span class="p">:</span> <span class="nb">dict</span><span class="p">,</span> <span class="n">source_ckpt</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    Given a model state dict and source checkpoints, the method tries to correct the keys in the model_state_dict to fit</span>
 <span class="sd">    Given a model state dict and source checkpoints, the method tries to correct the keys in the model_state_dict to fit</span>
 <span class="sd">    the ckpt in order to properly load the weights into the model. If unsuccessful - returns None</span>
 <span class="sd">    the ckpt in order to properly load the weights into the model. If unsuccessful - returns None</span>
@@ -233,39 +249,41 @@
 <span class="sd">                                        that returns a desired weight for ckpt_val.</span>
 <span class="sd">                                        that returns a desired weight for ckpt_val.</span>
 <span class="sd">        :return: renamed checkpoint dict (if possible)</span>
 <span class="sd">        :return: renamed checkpoint dict (if possible)</span>
 <span class="sd">    &quot;&quot;&quot;</span>
 <span class="sd">    &quot;&quot;&quot;</span>
-    <span class="k">if</span> <span class="s1">&#39;net&#39;</span> <span class="ow">in</span> <span class="n">source_ckpt</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
-        <span class="n">source_ckpt</span> <span class="o">=</span> <span class="n">source_ckpt</span><span class="p">[</span><span class="s1">&#39;net&#39;</span><span class="p">]</span>
+    <span class="k">if</span> <span class="s2">&quot;net&quot;</span> <span class="ow">in</span> <span class="n">source_ckpt</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
+        <span class="n">source_ckpt</span> <span class="o">=</span> <span class="n">source_ckpt</span><span class="p">[</span><span class="s2">&quot;net&quot;</span><span class="p">]</span>
     <span class="n">model_state_dict_excluded</span> <span class="o">=</span> <span class="p">{</span><span class="n">k</span><span class="p">:</span> <span class="n">v</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">model_state_dict</span><span class="o">.</span><span class="n">items</span><span class="p">()</span> <span class="k">if</span> <span class="ow">not</span> <span class="nb">any</sp
     <span class="n">model_state_dict_excluded</span> <span class="o">=</span> <span class="p">{</span><span class="n">k</span><span class="p">:</span> <span class="n">v</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">model_state_dict</span><span class="o">.</span><span class="n">items</span><span class="p">()</span> <span class="k">if</span> <span class="ow">not</span> <span class="nb">any</sp
     <span class="n">new_ckpt_dict</span> <span class="o">=</span> <span class="p">{}</span>
     <span class="n">new_ckpt_dict</span> <span class="o">=</span> <span class="p">{}</span>
     <span class="k">for</span> <span class="p">(</span><span class="n">ckpt_key</span><span class="p">,</span> <span class="n">ckpt_val</span><span class="p">),</span> <span class="p">(</span><span class="n">model_key</span><span class="p">,</span> <span class="n">model_val</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">source_ckpt</span><span class="o">.</span><span class="n">items</span><span class="p">(),</span> <s
     <span class="k">for</span> <span class="p">(</span><span class="n">ckpt_key</span><span class="p">,</span> <span class="n">ckpt_val</span><span class="p">),</span> <span class="p">(</span><span class="n">model_key</span><span class="p">,</span> <span class="n">model_val</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">source_ckpt</span><span class="o">.</span><span class="n">items</span><span class="p">(),</span> <s
         <span class="k">if</span> <span class="n">solver</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
         <span class="k">if</span> <span class="n">solver</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
             <span class="n">ckpt_val</span> <span class="o">=</span> <span class="n">solver</span><span class="p">(</span><span class="n">ckpt_key</span><span class="p">,</span> <span class="n">ckpt_val</span><span class="p">,</span> <span class="n">model_key</span><span class="p">,</span> <span class="n">model_val</span><span class="p">)</span>
             <span class="n">ckpt_val</span> <span class="o">=</span> <span class="n">solver</span><span class="p">(</span><span class="n">ckpt_key</span><span class="p">,</span> <span class="n">ckpt_val</span><span class="p">,</span> <span class="n">model_key</span><span class="p">,</span> <span class="n">model_val</span><span class="p">)</span>
         <span class="k">if</span> <span class="n">ckpt_val</span><span class="o">.</span><span class="n">shape</span> <span class="o">!=</span> <span class="n">model_val</span><span class="o">.</span><span class="n">shape</span><span class="p">:</span>
         <span class="k">if</span> <span class="n">ckpt_val</span><span class="o">.</span><span class="n">shape</span> <span class="o">!=</span> <span class="n">model_val</span><span class="o">.</span><span class="n">shape</span><span class="p">:</span>
-            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;ckpt layer </span><span class="si">{</span><span class="n">ckpt_key</span><span class="si">}</span><span class="s1"> with shape </span><span class="si">{</span><span class="n">ckpt_val</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s1"> does not match </span><span class="si">{</span><span class="n">mode
-                             <span class="sa">f</span><span class="s1">&#39; with shape </span><span class="si">{</span><span class="n">model_val</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s1"> in the model&#39;</span><span class="p">)</span>
+            <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;ckpt layer </span><span class="si">{</span><span class="n">ckpt_key</span><span class="si">}</span><span class="s2"> with shape </span><span class="si">{</span><span class="n">ckpt_val</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2"> does not match </span><span class="si">{</span><span class="n">mod
         <span class="n">new_ckpt_dict</span><span class="p">[</span><span class="n">model_key</span><span class="p">]</span> <span class="o">=</span> <span class="n">ckpt_val</span>
         <span class="n">new_ckpt_dict</span><span class="p">[</span><span class="n">model_key</span><span class="p">]</span> <span class="o">=</span> <span class="n">ckpt_val</span>
-    <span class="k">return</span> <span class="p">{</span><span class="s1">&#39;net&#39;</span><span class="p">:</span> <span class="n">new_ckpt_dict</span><span class="p">}</span></div>
+    <span class="k">return</span> <span class="p">{</span><span class="s2">&quot;net&quot;</span><span class="p">:</span> <span class="n">new_ckpt_dict</span><span class="p">}</span></div>
 
 
 
 
-<div class="viewcode-block" id="raise_informative_runtime_error"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.raise_informative_runtime_error">[docs]</a><span class="k">def</span> <span class="nf">raise_informative_runtime_error</span><span class="p">(</span><span class="n">state_dict</span><span class="p">,</span> <span class="n">checkpoint</span><span class="p">,</span> <span class="n">exception_msg</span><span class="p">):</spa
+<div class="viewcode-block" id="raise_informative_runtime_error"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.utils.raise_informative_runtime_error">[docs]</a><span class="k">def</span> <span class="nf">raise_informative_runtime_error</span><span class="p">(</span><span class="n">state_dict</span><span class="p">,</span> <span class="n">checkpoint</span><span class="p">,</span> <span class="n">exception_msg</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    Given a model state dict and source checkpoints, the method calls &quot;adapt_state_dict_to_fit_model_layer_names&quot;</span>
 <span class="sd">    Given a model state dict and source checkpoints, the method calls &quot;adapt_state_dict_to_fit_model_layer_names&quot;</span>
 <span class="sd">    and enhances the exception_msg if loading the checkpoint_dict via the conversion method is possible</span>
 <span class="sd">    and enhances the exception_msg if loading the checkpoint_dict via the conversion method is possible</span>
 <span class="sd">    &quot;&quot;&quot;</span>
 <span class="sd">    &quot;&quot;&quot;</span>
     <span class="k">try</span><span class="p">:</span>
     <span class="k">try</span><span class="p">:</span>
         <span class="n">new_ckpt_dict</span> <span class="o">=</span> <span class="n">adapt_state_dict_to_fit_model_layer_names</span><span class="p">(</span><span class="n">state_dict</span><span class="p">,</span> <span class="n">checkpoint</span><span class="p">)</span>
         <span class="n">new_ckpt_dict</span> <span class="o">=</span> <span class="n">adapt_state_dict_to_fit_model_layer_names</span><span class="p">(</span><span class="n">state_dict</span><span class="p">,</span> <span class="n">checkpoint</span><span class="p">)</span>
-        <span class="n">temp_file</span> <span class="o">=</span> <span class="n">tempfile</span><span class="o">.</span><span class="n">NamedTemporaryFile</span><span class="p">()</span><span class="o">.</span><span class="n">name</span> <span class="o">+</span> <span class="s1">&#39;.pt&#39;</span>
+        <span class="n">temp_file</span> <span class="o">=</span> <span class="n">tempfile</span><span class="o">.</span><span class="n">NamedTemporaryFile</span><span class="p">()</span><span class="o">.</span><span class="n">name</span> <span class="o">+</span> <span class="s2">&quot;.pt&quot;</span>
         <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">new_ckpt_dict</span><span class="p">,</span> <span class="n">temp_file</span><span class="p">)</span>
         <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">new_ckpt_dict</span><span class="p">,</span> <span class="n">temp_file</span><span class="p">)</span>
-        <span class="n">exception_msg</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="se">\n</span><span class="si">{</span><span class="s1">&#39;=&#39;</span> <span class="o">*</span> <span class="mi">200</span><span class="si">}</span><span class="se">\n</span><span class="si">{</span><span class="nb">str</span><span class="p">(</span><span class="n">exception_msg</span><span class="p">)</span><span class="si">}</span><span class="s2"> </spa
-                        <span class="sa">f</span><span class="s2">&quot;model_layer_names method</span><span class="se">\n</span><span class="s2">a converted checkpoint file was saved in the path </span><span class="si">{</span><span class="n">temp_file</span><span class="si">}</span><span class="se">\n</span><span class="si">{</span><span class="s1">&#39;=&#39;</span> <span class="o">*</span> <span class="mi">200</span><span class="si">}</span><span class="s2">&quot;</span>
+        <span class="n">exception_msg</span> <span class="o">=</span> <span class="p">(</span>
+            <span class="sa">f</span><span class="s2">&quot;</span><span class="se">\n</span><span class="si">{</span><span class="s1">&#39;=&#39;</span> <span class="o">*</span> <span class="mi">200</span><span class="si">}</span><span class="se">\n</span><span class="si">{</span><span class="nb">str</span><span class="p">(</span><span class="n">exception_msg</span><span class="p">)</span><span class="si">}</span><span class="s2"> </span><span class="se">\n</span><span class="s2">convert ckpt 
+            <span class="sa">f</span><span class="s2">&quot;model_layer_names method</span><span class="se">\n</span><span class="s2">a converted checkpoint file was saved in the path </span><span class="si">{</span><span class="n">temp_file</span><span class="si">}</span><span class="se">\n</span><span class="si">{</span><span class="s1">&#39;=&#39;</span> <span class="o">*</span> <span class="mi">200</span><span class="si">}</span><span class="s2">&quot;</span>
+        <span class="p">)</span>
     <span class="k">except</span> <span class="ne">ValueError</span> <span class="k">as</span> <span class="n">ex</span><span class="p">:</span>  <span class="c1"># IN CASE adapt_state_dict_to_fit_model_layer_names WAS UNSUCCESSFUL</span>
     <span class="k">except</span> <span class="ne">ValueError</span> <span class="k">as</span> <span class="n">ex</span><span class="p">:</span>  <span class="c1"># IN CASE adapt_state_dict_to_fit_model_layer_names WAS UNSUCCESSFUL</span>
         <span class="n">exception_msg</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="se">\n</span><span class="si">{</span><span class="s1">&#39;=&#39;</span> <span class="o">*</span> <span class="mi">200</span><span class="si">}</span><span class="s2"> </span><span class="se">\n</span><span class="s2">The checkpoint and model shapes do no fit, e.g.: </span><span class="si">{</span><span class="n">ex</span><span class="si">}</span><span class
         <span class="n">exception_msg</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="se">\n</span><span class="si">{</span><span class="s1">&#39;=&#39;</span> <span class="o">*</span> <span class="mi">200</span><span class="si">}</span><span class="s2"> </span><span class="se">\n</span><span class="s2">The checkpoint and model shapes do no fit, e.g.: </span><span class="si">{</span><span class="n">ex</span><span class="si">}</span><span class
     <span class="k">finally</span><span class="p">:</span>
     <span class="k">finally</span><span class="p">:</span>
         <span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="n">exception_msg</span><span class="p">)</span></div>
         <span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="n">exception_msg</span><span class="p">)</span></div>
 
 
 
 
-<div class="viewcode-block" id="load_checkpoint_to_model"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.load_checkpoint_to_model">[docs]</a><span class="k">def</span> <span class="nf">load_checkpoint_to_model</span><span class="p">(</span><span class="n">ckpt_local_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">load_backbone</span><span class="p">:</span> <span class="nb">boo
-                             <span class="n">load_weights_only</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span> <span class="n">load_ema_as_net</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
+<span class="k">def</span> <span class="nf">load_checkpoint_to_model</span><span class="p">(</span>
+    <span class="n">ckpt_local_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">load_backbone</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span> <span class="n">net</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">strict</span><span class="p">:</span> <span 
+<span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    Loads the state dict in ckpt_local_path to net and returns the checkpoint&#39;s state dict.</span>
 <span class="sd">    Loads the state dict in ckpt_local_path to net and returns the checkpoint&#39;s state dict.</span>
 
 
@@ -278,35 +296,39 @@
 <span class="sd">    @return:</span>
 <span class="sd">    @return:</span>
 <span class="sd">    &quot;&quot;&quot;</span>
 <span class="sd">    &quot;&quot;&quot;</span>
     <span class="k">if</span> <span class="n">ckpt_local_path</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">ckpt_local_path</span><span class="p">):</span>
     <span class="k">if</span> <span class="n">ckpt_local_path</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">ckpt_local_path</span><span class="p">):</span>
-        <span class="n">error_msg</span> <span class="o">=</span> <span class="s1">&#39;Error - loading Model Checkpoint: Path </span><span class="si">{}</span><span class="s1"> does not exist&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">ckpt_local_path</span><span class="p">)</span>
+        <span class="n">error_msg</span> <span class="o">=</span> <span class="s2">&quot;Error - loading Model Checkpoint: Path </span><span class="si">{}</span><span class="s2"> does not exist&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">ckpt_local_path</span><span class="p">)</span>
         <span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="n">error_msg</span><span class="p">)</span>
         <span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="n">error_msg</span><span class="p">)</span>
 
 
-    <span class="k">if</span> <span class="n">load_backbone</span> <span class="ow">and</span> <span class="ow">not</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">net</span><span class="o">.</span><span class="n">module</span><span class="p">,</span> <span class="s1">&#39;backbone&#39;</span><span class="p">):</span>
+    <span class="k">if</span> <span class="n">load_backbone</span> <span class="ow">and</span> <span class="ow">not</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">net</span><span class="p">,</span> <span class="s2">&quot;backbone&quot;</span><span class="p">):</span>
         <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;No backbone attribute in net - Can&#39;t load backbone weights&quot;</span><span class="p">)</span>
         <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;No backbone attribute in net - Can&#39;t load backbone weights&quot;</span><span class="p">)</span>
 
 
     <span class="c1"># LOAD THE LOCAL CHECKPOINT PATH INTO A state_dict OBJECT</span>
     <span class="c1"># LOAD THE LOCAL CHECKPOINT PATH INTO A state_dict OBJECT</span>
     <span class="n">checkpoint</span> <span class="o">=</span> <span class="n">read_ckpt_state_dict</span><span class="p">(</span><span class="n">ckpt_path</span><span class="o">=</span><span class="n">ckpt_local_path</span><span class="p">)</span>
     <span class="n">checkpoint</span> <span class="o">=</span> <span class="n">read_ckpt_state_dict</span><span class="p">(</span><span class="n">ckpt_path</span><span class="o">=</span><span class="n">ckpt_local_path</span><span class="p">)</span>
 
 
     <span class="k">if</span> <span class="n">load_ema_as_net</span><span class="p">:</span>
     <span class="k">if</span> <span class="n">load_ema_as_net</span><span class="p">:</span>
-        <span class="k">if</span> <span class="s1">&#39;ema_net&#39;</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">checkpoint</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
+        <span class="k">if</span> <span class="s2">&quot;ema_net&quot;</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">checkpoint</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
             <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Can&#39;t load ema network- no EMA network stored in checkpoint file&quot;</span><span class="p">)</span>
             <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Can&#39;t load ema network- no EMA network stored in checkpoint file&quot;</span><span class="p">)</span>
         <span class="k">else</span><span class="p">:</span>
         <span class="k">else</span><span class="p">:</span>
-            <span class="n">checkpoint</span><span class="p">[</span><span class="s1">&#39;net&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">checkpoint</span><span class="p">[</span><span class="s1">&#39;ema_net&#39;</span><span class="p">]</span>
+            <span class="n">checkpoint</span><span class="p">[</span><span class="s2">&quot;net&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">checkpoint</span><span class="p">[</span><span class="s2">&quot;ema_net&quot;</span><span class="p">]</span>
 
 
     <span class="c1"># LOAD THE CHECKPOINTS WEIGHTS TO THE MODEL</span>
     <span class="c1"># LOAD THE CHECKPOINTS WEIGHTS TO THE MODEL</span>
     <span class="k">if</span> <span class="n">load_backbone</span><span class="p">:</span>
     <span class="k">if</span> <span class="n">load_backbone</span><span class="p">:</span>
-        <span class="n">adaptive_load_state_dict</span><span class="p">(</span><span class="n">net</span><span class="o">.</span><span class="n">module</span><span class="o">.</span><span class="n">backbone</span><span class="p">,</span> <span class="n">checkpoint</span><span class="p">,</span> <span class="n">strict</span><span class="p">)</span>
+        <span class="n">adaptive_load_state_dict</span><span class="p">(</span><span class="n">net</span><span class="o">.</span><span class="n">backbone</span><span class="p">,</span> <span class="n">checkpoint</span><span class="p">,</span> <span class="n">strict</span><span class="p">)</span>
     <span class="k">else</span><span class="p">:</span>
     <span class="k">else</span><span class="p">:</span>
         <span class="n">adaptive_load_state_dict</span><span class="p">(</span><span class="n">net</span><span class="p">,</span> <span class="n">checkpoint</span><span class="p">,</span> <span class="n">strict</span><span class="p">)</span>
         <span class="n">adaptive_load_state_dict</span><span class="p">(</span><span class="n">net</span><span class="p">,</span> <span class="n">checkpoint</span><span class="p">,</span> <span class="n">strict</span><span class="p">)</span>
 
 
+    <span class="n">message_suffix</span> <span class="o">=</span> <span class="s2">&quot; checkpoint.&quot;</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">load_ema_as_net</span> <span class="k">else</span> <span class="s2">&quot; EMA checkpoint.&quot;</span>
+    <span class="n">message_model</span> <span class="o">=</span> <span class="s2">&quot;model&quot;</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">load_backbone</span> <span class="k">else</span> <span class="s2">&quot;model&#39;s backbone&quot;</span>
+    <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">&quot;Successfully loaded &quot;</span> <span class="o">+</span> <span class="n">message_model</span> <span class="o">+</span> <span class="s2">&quot; weights from &quot;</span> <span class="o">+</span> <span class="n">ckpt_local_path</span> <span class="o">+</span> <span class="n">message_suffix</span><span class="p">)</span>
+
     <span class="k">if</span> <span class="n">load_weights_only</span> <span class="ow">or</span> <span class="n">load_backbone</span><span class="p">:</span>
     <span class="k">if</span> <span class="n">load_weights_only</span> <span class="ow">or</span> <span class="n">load_backbone</span><span class="p">:</span>
         <span class="c1"># DISCARD ALL THE DATA STORED IN CHECKPOINT OTHER THAN THE WEIGHTS</span>
         <span class="c1"># DISCARD ALL THE DATA STORED IN CHECKPOINT OTHER THAN THE WEIGHTS</span>
-        <span class="p">[</span><span class="n">checkpoint</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="n">key</span><span class="p">)</span> <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="nb">list</span><span class="p">(</span><span class="n">checkpoint</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span> <span class="k">if</span> <span class="n">key</span> <span class=
+        <span class="p">[</span><span class="n">checkpoint</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="n">key</span><span class="p">)</span> <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="nb">list</span><span class="p">(</span><span class="n">checkpoint</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span> <span class="k">if</span> <span class="n">key</span> <span class=
 
 
-    <span class="k">return</span> <span class="n">checkpoint</span></div>
+    <span class="k">return</span> <span class="n">checkpoint</span>
 
 
 
 
-<div class="viewcode-block" id="MissingPretrainedWeightsException"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.MissingPretrainedWeightsException">[docs]</a><span class="k">class</span> <span class="nc">MissingPretrainedWeightsException</span><span class="p">(</span><span class="ne">Exception</span><span class="p">):</span>
+<span class="k">class</span> <span class="nc">MissingPretrainedWeightsException</span><span class="p">(</span><span class="ne">Exception</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;Exception raised by unsupported pretrianed model.</span>
     <span class="sd">&quot;&quot;&quot;Exception raised by unsupported pretrianed model.</span>
 
 
 <span class="sd">    Attributes:</span>
 <span class="sd">    Attributes:</span>
@@ -315,7 +337,7 @@
 
 
     <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">desc</span><span class="p">):</span>
     <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">desc</span><span class="p">):</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">message</span> <span class="o">=</span> <span class="s2">&quot;Missing pretrained wights: &quot;</span> <span class="o">+</span> <span class="n">desc</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">message</span> <span class="o">=</span> <span class="s2">&quot;Missing pretrained wights: &quot;</span> <span class="o">+</span> <span class="n">desc</span>
-        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">message</span><span class="p">)</span></div>
+        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">message</span><span class="p">)</span>
 
 
 
 
 <span class="k">def</span> <span class="nf">_yolox_ckpt_solver</span><span class="p">(</span><span class="n">ckpt_key</span><span class="p">,</span> <span class="n">ckpt_val</span><span class="p">,</span> <span class="n">model_key</span><span class="p">,</span> <span class="n">model_val</span><span class="p">):</span>
 <span class="k">def</span> <span class="nf">_yolox_ckpt_solver</span><span class="p">(</span><span class="n">ckpt_key</span><span class="p">,</span> <span class="n">ckpt_val</span><span class="p">,</span> <span class="n">model_key</span><span class="p">,</span> <span class="n">model_val</span><span class="p">):</span>
@@ -323,8 +345,11 @@
 <span class="sd">    Helper method for reshaping old pretrained checkpoint&#39;s focus weights to 6x6 conv weights.</span>
 <span class="sd">    Helper method for reshaping old pretrained checkpoint&#39;s focus weights to 6x6 conv weights.</span>
 <span class="sd">    &quot;&quot;&quot;</span>
 <span class="sd">    &quot;&quot;&quot;</span>
 
 
-    <span class="k">if</span> <span class="n">ckpt_val</span><span class="o">.</span><span class="n">shape</span> <span class="o">!=</span> <span class="n">model_val</span><span class="o">.</span><span class="n">shape</span> <span class="ow">and</span> <span class="n">ckpt_key</span> <span class="o">==</span> <span class="s1">&#39;module._backbone._modules_list.0.conv.conv.weight&#39;</span> <span class="ow">and</span> \
-            <span class="n">model_key</span> <span class="o">==</span> <span class="s1">&#39;_backbone._modules_list.0.conv.weight&#39;</span><span class="p">:</span>
+    <span class="k">if</span> <span class="p">(</span>
+        <span class="n">ckpt_val</span><span class="o">.</span><span class="n">shape</span> <span class="o">!=</span> <span class="n">model_val</span><span class="o">.</span><span class="n">shape</span>
+        <span class="ow">and</span> <span class="n">ckpt_key</span> <span class="o">==</span> <span class="s2">&quot;module._backbone._modules_list.0.conv.conv.weight&quot;</span>
+        <span class="ow">and</span> <span class="n">model_key</span> <span class="o">==</span> <span class="s2">&quot;_backbone._modules_list.0.conv.weight&quot;</span>
+    <span class="p">):</span>
         <span class="n">model_val</span><span class="o">.</span><span class="n">data</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">::</span><span class="mi">2</span><span class="p">,</span> <span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">ckpt_val</span><span class="o">.</span><span class="n">data</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">3</span><span class="p">]</spa
         <span class="n">model_val</span><span class="o">.</span><span class="n">data</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">::</span><span class="mi">2</span><span class="p">,</span> <span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">ckpt_val</span><span class="o">.</span><span class="n">data</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">3</span><span class="p">]</spa
         <span class="n">model_val</span><span class="o">.</span><span class="n">data</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">,</span> <span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">ckpt_val</span><span class="o">.</span><span class="n">data</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">:</sp
         <span class="n">model_val</span><span class="o">.</span><span class="n">data</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">,</span> <span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">ckpt_val</span><span class="o">.</span><span class="n">data</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">:</sp
         <span class="n">model_val</span><span class="o">.</span><span class="n">data</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">::</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">ckpt_val</span><span class="o">.</span><span class="n">data</span><span class="p">[:,</span> <span class="mi">6</span><span class="p">:</sp
         <span class="n">model_val</span><span class="o">.</span><span class="n">data</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">::</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">ckpt_val</span><span class="o">.</span><span class="n">data</span><span class="p">[:,</span> <span class="mi">6</span><span class="p">:</sp
@@ -336,7 +361,7 @@
     <span class="k">return</span> <span class="n">replacement</span>
     <span class="k">return</span> <span class="n">replacement</span>
 
 
 
 
-<div class="viewcode-block" id="load_pretrained_weights"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.load_pretrained_weights">[docs]</a><span class="k">def</span> <span class="nf">load_pretrained_weights</span><span class="p">(</span><span class="n">model</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">
+<span class="k">def</span> <span class="nf">load_pretrained_weights</span><span class="p">(</span><span class="n">model</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">architecture</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">pretrained_weights</span><span class="p">:</span> <span class="nb
 
 
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    Loads pretrained weights from the MODEL_URLS dictionary to model</span>
 <span class="sd">    Loads pretrained weights from the MODEL_URLS dictionary to model</span>
@@ -345,21 +370,41 @@
 <span class="sd">    @param pretrained_weights: name for the pretrianed weights (i.e imagenet)</span>
 <span class="sd">    @param pretrained_weights: name for the pretrianed weights (i.e imagenet)</span>
 <span class="sd">    @return: None</span>
 <span class="sd">    @return: None</span>
 <span class="sd">    &quot;&quot;&quot;</span>
 <span class="sd">    &quot;&quot;&quot;</span>
-    <span class="n">model_url_key</span> <span class="o">=</span> <span class="n">architecture</span> <span class="o">+</span> <span class="s1">&#39;_&#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">pretrained_weights</span><span class="p">)</span>
+    <span class="n">model_url_key</span> <span class="o">=</span> <span class="n">architecture</span> <span class="o">+</span> <span class="s2">&quot;_&quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">pretrained_weights</span><span class="p">)</span>
     <span class="k">if</span> <span class="n">model_url_key</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">MODEL_URLS</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
     <span class="k">if</span> <span class="n">model_url_key</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">MODEL_URLS</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
         <span class="k">raise</span> <span class="n">MissingPretrainedWeightsException</span><span class="p">(</span><span class="n">model_url_key</span><span class="p">)</span>
         <span class="k">raise</span> <span class="n">MissingPretrainedWeightsException</span><span class="p">(</span><span class="n">model_url_key</span><span class="p">)</span>
 
 
     <span class="n">url</span> <span class="o">=</span> <span class="n">MODEL_URLS</span><span class="p">[</span><span class="n">model_url_key</span><span class="p">]</span>
     <span class="n">url</span> <span class="o">=</span> <span class="n">MODEL_URLS</span><span class="p">[</span><span class="n">model_url_key</span><span class="p">]</span>
-    <span class="n">unique_filename</span> <span class="o">=</span> <span class="n">url</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">&quot;https://deci-pretrained-models.s3.amazonaws.com/&quot;</span><span class="p">)[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s1">&#39;/&#39;</span><span class="p">,</span> <span class="s1">&#39;_&#39;</spa
-    <span class="n">map_location</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s1">&#39;cpu&#39;</span><span class="p">)</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">()</span> <span class="k">else</span> <span class="kc">None
+    <span class="n">unique_filename</span> <span class="o">=</span> <span class="n">url</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">&quot;https://deci-pretrained-models.s3.amazonaws.com/&quot;</span><span class="p">)[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">&quot;/&quot;</span><span class="p">,</span> <span class="s2">&quot;_&quot;<
+    <span class="n">map_location</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s2">&quot;cpu&quot;</span><span class="p">)</span>
     <span class="n">pretrained_state_dict</span> <span class="o">=</span> <span class="n">load_state_dict_from_url</span><span class="p">(</span><span class="n">url</span><span class="o">=</span><span class="n">url</span><span class="p">,</span> <span class="n">map_location</span><span class="o">=</span><span class="n">map_location</span><span class="p">,</span> <span class="n">file_name</span><span class="o">=</span><span class="n">unique_filename</span><span class="p">)</span>
     <span class="n">pretrained_state_dict</span> <span class="o">=</span> <span class="n">load_state_dict_from_url</span><span class="p">(</span><span class="n">url</span><span class="o">=</span><span class="n">url</span><span class="p">,</span> <span class="n">map_location</span><span class="o">=</span><span class="n">map_location</span><span class="p">,</span> <span class="n">file_name</span><span class="o">=</span><span class="n">unique_filename</span><span class="p">)</span>
-    <span class="k">if</span> <span class="s1">&#39;ema_net&#39;</span> <span class="ow">in</span> <span class="n">pretrained_state_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
-        <span class="n">pretrained_state_dict</span><span class="p">[</span><span class="s1">&#39;net&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">pretrained_state_dict</span><span class="p">[</span><span class="s1">&#39;ema_net&#39;</span><span class="p">]</span>
+    <span class="n">_load_weights</span><span class="p">(</span><span class="n">architecture</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">pretrained_state_dict</span><span class="p">)</span>
+
+
+<span class="k">def</span> <span class="nf">_load_weights</span><span class="p">(</span><span class="n">architecture</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">pretrained_state_dict</span><span class="p">):</span>
+    <span class="k">if</span> <span class="s2">&quot;ema_net&quot;</span> <span class="ow">in</span> <span class="n">pretrained_state_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
+        <span class="n">pretrained_state_dict</span><span class="p">[</span><span class="s2">&quot;net&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">pretrained_state_dict</span><span class="p">[</span><span class="s2">&quot;ema_net&quot;</span><span class="p">]</span>
     <span class="n">solver</span> <span class="o">=</span> <span class="n">_yolox_ckpt_solver</span> <span class="k">if</span> <span class="s2">&quot;yolox&quot;</span> <span class="ow">in</span> <span class="n">architecture</span> <span class="k">else</span> <span class="kc">None</span>
     <span class="n">solver</span> <span class="o">=</span> <span class="n">_yolox_ckpt_solver</span> <span class="k">if</span> <span class="s2">&quot;yolox&quot;</span> <span class="ow">in</span> <span class="n">architecture</span> <span class="k">else</span> <span class="kc">None</span>
-    <span class="n">adapted_pretrained_state_dict</span> <span class="o">=</span> <span class="n">adapt_state_dict_to_fit_model_layer_names</span><span class="p">(</span><span class="n">model_state_dict</span><span class="o">=</span><span class="n">model</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span>
-                                                                              <span class="n">source_ckpt</span><span class="o">=</span><span class="n">pretrained_state_dict</span><span class="p">,</span>
-                                                                              <span class="n">solver</span><span class="o">=</span><span class="n">solver</span><span class="p">)</span>
-    <span class="n">model</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">adapted_pretrained_state_dict</span><span class="p">[</span><span class="s1">&#39;net&#39;</span><span class="p">],</span> <span class="n">strict</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></div>
+    <span class="n">adapted_pretrained_state_dict</span> <span class="o">=</span> <span class="n">adapt_state_dict_to_fit_model_layer_names</span><span class="p">(</span>
+        <span class="n">model_state_dict</span><span class="o">=</span><span class="n">model</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span> <span class="n">source_ckpt</span><span class="o">=</span><span class="n">pretrained_state_dict</span><span class="p">,</span> <span class="n">solver</span><span class="o">=</span><span class="n">solver</span>
+    <span class="p">)</span>
+    <span class="n">model</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">adapted_pretrained_state_dict</span><span class="p">[</span><span class="s2">&quot;net&quot;</span><span class="p">],</span> <span class="n">strict</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
+
+
+<span class="k">def</span> <span class="nf">load_pretrained_weights_local</span><span class="p">(</span><span class="n">model</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">architecture</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">pretrained_weights</span><span class="p">:</span> <span cla
+
+    <span class="sd">&quot;&quot;&quot;</span>
+<span class="sd">    Loads pretrained weights from the MODEL_URLS dictionary to model</span>
+<span class="sd">    @param architecture: name of the model&#39;s architecture</span>
+<span class="sd">    @param model: model to load pretrinaed weights for</span>
+<span class="sd">    @param pretrained_weights: path tp pretrained weights</span>
+<span class="sd">    @return: None</span>
+<span class="sd">    &quot;&quot;&quot;</span>
+
+    <span class="n">map_location</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s2">&quot;cpu&quot;</span><span class="p">)</span>
+
+    <span class="n">pretrained_state_dict</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">pretrained_weights</span><span class="p">,</span> <span class="n">map_location</span><span class="o">=</span><span class="n">map_location</span><span class="p">)</span>
+    <span class="n">_load_weights</span><span class="p">(</span><span class="n">architecture</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">pretrained_state_dict</span><span class="p">)</span>
 </pre></div>
 </pre></div>
 
 
            </div>
            </div>
@@ -389,4 +434,4 @@
   </script> 
   </script> 
 
 
 </body>
 </body>
-</html>
+</html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.training.utils.detection_utils &mdash; SuperGradients 1.0 documentation</title>
  7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
  10. <!--[if lt IE 9]>
  11. <script src="../../../../_static/js/html5shiv.min.js"></script>
  12. <![endif]-->
  13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
  14. <script src="../../../../_static/jquery.js"></script>
  15. <script src="../../../../_static/underscore.js"></script>
  16. <script src="../../../../_static/doctools.js"></script>
  17. <script src="../../../../_static/js/theme.js"></script>
  18. <link rel="index" title="Index" href="../../../../genindex.html" />
  19. <link rel="search" title="Search" href="../../../../search.html" />
  20. </head>
  21. <body class="wy-body-for-nav">
  22. <div class="wy-grid-for-nav">
  23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  24. <div class="wy-side-scroll">
  25. <div class="wy-side-nav-search" >
  26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
  27. </a>
  28. <div role="search">
  29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
  30. <input type="text" name="q" placeholder="Search docs" />
  31. <input type="hidden" name="check_keywords" value="yes" />
  32. <input type="hidden" name="area" value="default" />
  33. </form>
  34. </div>
  35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
  37. <ul>
  38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
  39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
  40. </ul>
  41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
  42. <ul>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
  45. </ul>
  46. <p class="caption"><span class="caption-text">User Guide</span></p>
  47. <ul>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
  57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
  58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
  59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
  60. </ul>
  61. </div>
  62. </div>
  63. </nav>
  64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  66. <a href="../../../../index.html">SuperGradients</a>
  67. </nav>
  68. <div class="wy-nav-content">
  69. <div class="rst-content">
  70. <div role="navigation" aria-label="Page navigation">
  71. <ul class="wy-breadcrumbs">
  72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
  74. <li>super_gradients.training.utils.detection_utils</li>
  75. <li class="wy-breadcrumbs-aside">
  76. </li>
  77. </ul>
  78. <hr/>
  79. </div>
  80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  81. <div itemprop="articleBody">
  82. <h1>Source code for super_gradients.training.utils.detection_utils</h1><div class="highlight"><pre>
  83. <span></span><span class="kn">import</span> <span class="nn">math</span>
  84. <span class="kn">import</span> <span class="nn">os</span>
  85. <span class="kn">import</span> <span class="nn">pathlib</span>
  86. <span class="kn">from</span> <span class="nn">abc</span> <span class="kn">import</span> <span class="n">ABC</span><span class="p">,</span> <span class="n">abstractmethod</span>
  87. <span class="kn">from</span> <span class="nn">enum</span> <span class="kn">import</span> <span class="n">Enum</span>
  88. <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Dict</span>
  89. <span class="kn">import</span> <span class="nn">cv2</span>
  90. <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>
  91. <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
  92. <span class="kn">import</span> <span class="nn">torch</span>
  93. <span class="kn">import</span> <span class="nn">torchvision</span>
  94. <span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
  95. <span class="kn">from</span> <span class="nn">torch.utils.data._utils.collate</span> <span class="kn">import</span> <span class="n">default_collate</span>
  96. <span class="kn">from</span> <span class="nn">omegaconf</span> <span class="kn">import</span> <span class="n">ListConfig</span>
  97. <div class="viewcode-block" id="DetectionTargetsFormat"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.DetectionTargetsFormat">[docs]</a><span class="k">class</span> <span class="nc">DetectionTargetsFormat</span><span class="p">(</span><span class="n">Enum</span><span class="p">):</span>
  98. <span class="sd">&quot;&quot;&quot;</span>
  99. <span class="sd"> Enum class for the different detection output formats</span>
  100. <span class="sd"> When NORMALIZED is not specified- the type refers to unnormalized image coordinates (of the bboxes).</span>
  101. <span class="sd"> For example:</span>
  102. <span class="sd"> LABEL_NORMALIZED_XYXY means [class_idx,x1,y1,x2,y2]</span>
  103. <span class="sd"> &quot;&quot;&quot;</span>
  104. <span class="n">LABEL_XYXY</span> <span class="o">=</span> <span class="s2">&quot;LABEL_XYXY&quot;</span>
  105. <span class="n">XYXY_LABEL</span> <span class="o">=</span> <span class="s2">&quot;XYXY_LABEL&quot;</span>
  106. <span class="n">LABEL_NORMALIZED_XYXY</span> <span class="o">=</span> <span class="s2">&quot;LABEL_NORMALIZED_XYXY&quot;</span>
  107. <span class="n">NORMALIZED_XYXY_LABEL</span> <span class="o">=</span> <span class="s2">&quot;NORMALIZED_XYXY_LABEL&quot;</span>
  108. <span class="n">LABEL_CXCYWH</span> <span class="o">=</span> <span class="s2">&quot;LABEL_CXCYWH&quot;</span>
  109. <span class="n">CXCYWH_LABEL</span> <span class="o">=</span> <span class="s2">&quot;CXCYWH_LABEL&quot;</span>
  110. <span class="n">LABEL_NORMALIZED_CXCYWH</span> <span class="o">=</span> <span class="s2">&quot;LABEL_NORMALIZED_CXCYWH&quot;</span>
  111. <span class="n">NORMALIZED_CXCYWH_LABEL</span> <span class="o">=</span> <span class="s2">&quot;NORMALIZED_CXCYWH_LABEL&quot;</span></div>
  112. <div class="viewcode-block" id="get_cls_posx_in_target"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.get_cls_posx_in_target">[docs]</a><span class="k">def</span> <span class="nf">get_cls_posx_in_target</span><span class="p">(</span><span class="n">target_format</span><span class="p">:</span> <span class="n">DetectionTargetsFormat</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
  113. <span class="sd">&quot;&quot;&quot;Get the label of a given target</span>
  114. <span class="sd"> :param target_format: Representation of the target (ex: LABEL_XYXY)</span>
  115. <span class="sd"> :return: Position of the class id in a bbox</span>
  116. <span class="sd"> ex: 0 if bbox of format label_xyxy | -1 if bbox of format xyxy_label</span>
  117. <span class="sd"> &quot;&quot;&quot;</span>
  118. <span class="n">format_split</span> <span class="o">=</span> <span class="n">target_format</span><span class="o">.</span><span class="n">value</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">&quot;_&quot;</span><span class="p">)</span>
  119. <span class="k">if</span> <span class="n">format_split</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="s2">&quot;LABEL&quot;</span><span class="p">:</span>
  120. <span class="k">return</span> <span class="mi">0</span>
  121. <span class="k">elif</span> <span class="n">format_split</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="s2">&quot;LABEL&quot;</span><span class="p">:</span>
  122. <span class="k">return</span> <span class="o">-</span><span class="mi">1</span>
  123. <span class="k">else</span><span class="p">:</span>
  124. <span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;No implementation to find index of LABEL in </span><span class="si">{</span><span class="n">target_format</span><span class="o">.</span><span class="n">value</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span></div>
  125. <span class="k">def</span> <span class="nf">_set_batch_labels_index</span><span class="p">(</span><span class="n">labels_batch</span><span class="p">):</span>
  126. <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">labels</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">labels_batch</span><span class="p">):</span>
  127. <span class="n">labels</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">i</span>
  128. <span class="k">return</span> <span class="n">labels_batch</span>
  129. <div class="viewcode-block" id="convert_xywh_bbox_to_xyxy"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.convert_xywh_bbox_to_xyxy">[docs]</a><span class="k">def</span> <span class="nf">convert_xywh_bbox_to_xyxy</span><span class="p">(</span><span class="n">input_bbox</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
  130. <span class="sd">&quot;&quot;&quot;</span>
  131. <span class="sd"> Converts bounding box format from [x, y, w, h] to [x1, y1, x2, y2]</span>
  132. <span class="sd"> :param input_bbox: input bbox either 2-dimensional (for all boxes of a single image) or 3-dimensional (for</span>
  133. <span class="sd"> boxes of a batch of images)</span>
  134. <span class="sd"> :return: Converted bbox in same dimensions as the original</span>
  135. <span class="sd"> &quot;&quot;&quot;</span>
  136. <span class="n">need_squeeze</span> <span class="o">=</span> <span class="kc">False</span>
  137. <span class="c1"># the input is always processed as a batch. in case it not a batch, it is unsqueezed, process and than squeeze back.</span>
  138. <span class="k">if</span> <span class="n">input_bbox</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">&lt;</span> <span class="mi">3</span><span class="p">:</span>
  139. <span class="n">need_squeeze</span> <span class="o">=</span> <span class="kc">True</span>
  140. <span class="n">input_bbox</span> <span class="o">=</span> <span class="n">input_bbox</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
  141. <span class="n">converted_bbox</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">input_bbox</span><span class="p">)</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">input_bbox</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="k">else</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">input_bbox</span><span class="p">)</span>
  142. <span class="n">converted_bbox</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">input_bbox</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">-</span> <span class="n">input_bbox</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">/</span> <span class="mi">2</span>
  143. <span class="n">converted_bbox</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">input_bbox</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="n">input_bbox</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">/</span> <span class="mi">2</span>
  144. <span class="n">converted_bbox</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">input_bbox</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="n">input_bbox</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">/</span> <span class="mi">2</span>
  145. <span class="n">converted_bbox</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">=</span> <span class="n">input_bbox</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="n">input_bbox</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">/</span> <span class="mi">2</span>
  146. <span class="c1"># squeeze back if needed</span>
  147. <span class="k">if</span> <span class="n">need_squeeze</span><span class="p">:</span>
  148. <span class="n">converted_bbox</span> <span class="o">=</span> <span class="n">converted_bbox</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
  149. <span class="k">return</span> <span class="n">converted_bbox</span></div>
  150. <span class="k">def</span> <span class="nf">_iou</span><span class="p">(</span><span class="n">CIoU</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span> <span class="n">DIoU</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span> <span class="n">GIoU</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span> <span class="n">b1_x1</span><span class="p">,</span> <span class="n">b1_x2</span><span class="p">,</span> <span class="n">b1_y1</span><span class="p">,</span> <span class="n">b1_y2</span><span class="p">,</span> <span class="n">b2_x1</span><span class="p">,</span> <span class="n">b2_x2</span><span class="p">,</span> <span class="n">b2_y1</span><span class="p">,</span> <span class="n">b2_y2</span><span class="p">,</span> <span class="n">eps</span><span class="p">):</span>
  151. <span class="sd">&quot;&quot;&quot;</span>
  152. <span class="sd"> Internal function for the use of calculate_bbox_iou_matrix and calculate_bbox_iou_elementwise functions</span>
  153. <span class="sd"> DO NOT CALL THIS FUNCTIONS DIRECTLY - use one of the functions mentioned above</span>
  154. <span class="sd"> &quot;&quot;&quot;</span>
  155. <span class="c1"># Intersection area</span>
  156. <span class="n">intersection_area</span> <span class="o">=</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">b1_x2</span><span class="p">,</span> <span class="n">b2_x2</span><span class="p">)</span> <span class="o">-</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">b1_x1</span><span class="p">,</span> <span class="n">b2_x1</span><span class="p">))</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">*</span> \
  157. <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">b1_y2</span><span class="p">,</span> <span class="n">b2_y2</span><span class="p">)</span> <span class="o">-</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">b1_y1</span><span class="p">,</span> <span class="n">b2_y1</span><span class="p">))</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
  158. <span class="c1"># Union Area</span>
  159. <span class="n">w1</span><span class="p">,</span> <span class="n">h1</span> <span class="o">=</span> <span class="n">b1_x2</span> <span class="o">-</span> <span class="n">b1_x1</span><span class="p">,</span> <span class="n">b1_y2</span> <span class="o">-</span> <span class="n">b1_y1</span>
  160. <span class="n">w2</span><span class="p">,</span> <span class="n">h2</span> <span class="o">=</span> <span class="n">b2_x2</span> <span class="o">-</span> <span class="n">b2_x1</span><span class="p">,</span> <span class="n">b2_y2</span> <span class="o">-</span> <span class="n">b2_y1</span>
  161. <span class="n">union_area</span> <span class="o">=</span> <span class="n">w1</span> <span class="o">*</span> <span class="n">h1</span> <span class="o">+</span> <span class="n">w2</span> <span class="o">*</span> <span class="n">h2</span> <span class="o">-</span> <span class="n">intersection_area</span> <span class="o">+</span> <span class="n">eps</span>
  162. <span class="n">iou</span> <span class="o">=</span> <span class="n">intersection_area</span> <span class="o">/</span> <span class="n">union_area</span> <span class="c1"># iou</span>
  163. <span class="k">if</span> <span class="n">GIoU</span> <span class="ow">or</span> <span class="n">DIoU</span> <span class="ow">or</span> <span class="n">CIoU</span><span class="p">:</span>
  164. <span class="n">cw</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">b1_x2</span><span class="p">,</span> <span class="n">b2_x2</span><span class="p">)</span> <span class="o">-</span> <span class="n">torch</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">b1_x1</span><span class="p">,</span> <span class="n">b2_x1</span><span class="p">)</span> <span class="c1"># convex (smallest enclosing box) width</span>
  165. <span class="n">ch</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">b1_y2</span><span class="p">,</span> <span class="n">b2_y2</span><span class="p">)</span> <span class="o">-</span> <span class="n">torch</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">b1_y1</span><span class="p">,</span> <span class="n">b2_y1</span><span class="p">)</span> <span class="c1"># convex height</span>
  166. <span class="c1"># Generalized IoU https://arxiv.org/pdf/1902.09630.pdf</span>
  167. <span class="k">if</span> <span class="n">GIoU</span><span class="p">:</span>
  168. <span class="n">c_area</span> <span class="o">=</span> <span class="n">cw</span> <span class="o">*</span> <span class="n">ch</span> <span class="o">+</span> <span class="n">eps</span> <span class="c1"># convex area</span>
  169. <span class="n">iou</span> <span class="o">-=</span> <span class="p">(</span><span class="n">c_area</span> <span class="o">-</span> <span class="n">union_area</span><span class="p">)</span> <span class="o">/</span> <span class="n">c_area</span> <span class="c1"># GIoU</span>
  170. <span class="c1"># Distance or Complete IoU https://arxiv.org/abs/1911.08287v1</span>
  171. <span class="k">if</span> <span class="n">DIoU</span> <span class="ow">or</span> <span class="n">CIoU</span><span class="p">:</span>
  172. <span class="c1"># convex diagonal squared</span>
  173. <span class="n">c2</span> <span class="o">=</span> <span class="n">cw</span> <span class="o">**</span> <span class="mi">2</span> <span class="o">+</span> <span class="n">ch</span> <span class="o">**</span> <span class="mi">2</span> <span class="o">+</span> <span class="n">eps</span>
  174. <span class="c1"># centerpoint distance squared</span>
  175. <span class="n">rho2</span> <span class="o">=</span> <span class="p">((</span><span class="n">b2_x1</span> <span class="o">+</span> <span class="n">b2_x2</span> <span class="o">-</span> <span class="n">b1_x1</span> <span class="o">-</span> <span class="n">b1_x2</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span> <span class="o">+</span> <span class="p">(</span><span class="n">b2_y1</span> <span class="o">+</span> <span class="n">b2_y2</span> <span class="o">-</span> <span class="n">b1_y1</span> <span class="o">-</span> <span class="n">b1_y2</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span> <span class="o">/</span> <span class="mi">4</span>
  176. <span class="k">if</span> <span class="n">DIoU</span><span class="p">:</span>
  177. <span class="n">iou</span> <span class="o">-=</span> <span class="n">rho2</span> <span class="o">/</span> <span class="n">c2</span> <span class="c1"># DIoU</span>
  178. <span class="k">elif</span> <span class="n">CIoU</span><span class="p">:</span> <span class="c1"># https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47</span>
  179. <span class="n">v</span> <span class="o">=</span> <span class="p">(</span><span class="mi">4</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">pi</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">pow</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">atan</span><span class="p">(</span><span class="n">w2</span> <span class="o">/</span> <span class="n">h2</span><span class="p">)</span> <span class="o">-</span> <span class="n">torch</span><span class="o">.</span><span class="n">atan</span><span class="p">(</span><span class="n">w1</span> <span class="o">/</span> <span class="n">h1</span><span class="p">),</span> <span class="mi">2</span><span class="p">)</span>
  180. <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
  181. <span class="n">alpha</span> <span class="o">=</span> <span class="n">v</span> <span class="o">/</span> <span class="p">((</span><span class="mi">1</span> <span class="o">+</span> <span class="n">eps</span><span class="p">)</span> <span class="o">-</span> <span class="n">iou</span> <span class="o">+</span> <span class="n">v</span><span class="p">)</span>
  182. <span class="n">iou</span> <span class="o">-=</span> <span class="p">(</span><span class="n">rho2</span> <span class="o">/</span> <span class="n">c2</span> <span class="o">+</span> <span class="n">v</span> <span class="o">*</span> <span class="n">alpha</span><span class="p">)</span> <span class="c1"># CIoU</span>
  183. <span class="k">return</span> <span class="n">iou</span>
  184. <div class="viewcode-block" id="calculate_bbox_iou_matrix"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.calculate_bbox_iou_matrix">[docs]</a><span class="k">def</span> <span class="nf">calculate_bbox_iou_matrix</span><span class="p">(</span><span class="n">box1</span><span class="p">,</span> <span class="n">box2</span><span class="p">,</span> <span class="n">x1y1x2y2</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">GIoU</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span> <span class="n">DIoU</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">CIoU</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-9</span><span class="p">):</span>
  185. <span class="sd">&quot;&quot;&quot;</span>
  186. <span class="sd"> calculate iou matrix containing the iou of every couple iuo(i,j) where i is in box1 and j is in box2</span>
  187. <span class="sd"> :param box1: a 2D tensor of boxes (shape N x 4)</span>
  188. <span class="sd"> :param box2: a 2D tensor of boxes (shape M x 4)</span>
  189. <span class="sd"> :param x1y1x2y2: boxes format is x1y1x2y2 (True) or xywh where xy is the center (False)</span>
  190. <span class="sd"> :return: a 2D iou matrix (shape NxM)</span>
  191. <span class="sd"> &quot;&quot;&quot;</span>
  192. <span class="k">if</span> <span class="n">box1</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
  193. <span class="n">box1</span> <span class="o">=</span> <span class="n">box1</span><span class="o">.</span><span class="n">T</span>
  194. <span class="c1"># Get the coordinates of bounding boxes</span>
  195. <span class="k">if</span> <span class="n">x1y1x2y2</span><span class="p">:</span> <span class="c1"># x1, y1, x2, y2 = box1</span>
  196. <span class="n">b1_x1</span><span class="p">,</span> <span class="n">b1_y1</span><span class="p">,</span> <span class="n">b1_x2</span><span class="p">,</span> <span class="n">b1_y2</span> <span class="o">=</span> <span class="n">box1</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">box1</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">box1</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="n">box1</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span>
  197. <span class="n">b2_x1</span><span class="p">,</span> <span class="n">b2_y1</span><span class="p">,</span> <span class="n">b2_x2</span><span class="p">,</span> <span class="n">b2_y2</span> <span class="o">=</span> <span class="n">box2</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">box2</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">box2</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">],</span> <span class="n">box2</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span>
  198. <span class="k">else</span><span class="p">:</span> <span class="c1"># x, y, w, h = box1</span>
  199. <span class="n">b1_x1</span><span class="p">,</span> <span class="n">b1_x2</span> <span class="o">=</span> <span class="n">box1</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">-</span> <span class="n">box1</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">/</span> <span class="mi">2</span><span class="p">,</span> <span class="n">box1</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="n">box1</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">/</span> <span class="mi">2</span>
  200. <span class="n">b1_y1</span><span class="p">,</span> <span class="n">b1_y2</span> <span class="o">=</span> <span class="n">box1</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="n">box1</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span> <span class="o">/</span> <span class="mi">2</span><span class="p">,</span> <span class="n">box1</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="n">box1</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span> <span class="o">/</span> <span class="mi">2</span>
  201. <span class="n">b2_x1</span><span class="p">,</span> <span class="n">b2_x2</span> <span class="o">=</span> <span class="n">box2</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">-</span> <span class="n">box2</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">/</span> <span class="mi">2</span><span class="p">,</span> <span class="n">box2</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="n">box2</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">/</span> <span class="mi">2</span>
  202. <span class="n">b2_y1</span><span class="p">,</span> <span class="n">b2_y2</span> <span class="o">=</span> <span class="n">box2</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="n">box2</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">/</span> <span class="mi">2</span><span class="p">,</span> <span class="n">box2</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="n">box2</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">/</span> <span class="mi">2</span>
  203. <span class="n">b1_x1</span><span class="p">,</span> <span class="n">b1_y1</span><span class="p">,</span> <span class="n">b1_x2</span><span class="p">,</span> <span class="n">b1_y2</span> <span class="o">=</span> <span class="n">b1_x1</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">b1_y1</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">b1_x2</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">b1_y2</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
  204. <span class="k">return</span> <span class="n">_iou</span><span class="p">(</span><span class="n">CIoU</span><span class="p">,</span> <span class="n">DIoU</span><span class="p">,</span> <span class="n">GIoU</span><span class="p">,</span> <span class="n">b1_x1</span><span class="p">,</span> <span class="n">b1_x2</span><span class="p">,</span> <span class="n">b1_y1</span><span class="p">,</span> <span class="n">b1_y2</span><span class="p">,</span> <span class="n">b2_x1</span><span class="p">,</span> <span class="n">b2_x2</span><span class="p">,</span> <span class="n">b2_y1</span><span class="p">,</span> <span class="n">b2_y2</span><span class="p">,</span> <span class="n">eps</span><span class="p">)</span></div>
  205. <div class="viewcode-block" id="calc_bbox_iou_matrix"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.calc_bbox_iou_matrix">[docs]</a><span class="k">def</span> <span class="nf">calc_bbox_iou_matrix</span><span class="p">(</span><span class="n">pred</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
  206. <span class="sd">&quot;&quot;&quot;</span>
  207. <span class="sd"> calculate iou for every pair of boxes in the boxes vector</span>
  208. <span class="sd"> :param pred: a 3-dimensional tensor containing all boxes for a batch of images [N, num_boxes, 4], where</span>
  209. <span class="sd"> each box format is [x1,y1,x2,y2]</span>
  210. <span class="sd"> :return: a 3-dimensional matrix where M_i_j_k is the iou of box j and box k of the i&#39;th image in the batch</span>
  211. <span class="sd"> &quot;&quot;&quot;</span>
  212. <span class="n">box</span> <span class="o">=</span> <span class="n">pred</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">]</span> <span class="c1">#</span>
  213. <span class="n">b1_x1</span><span class="p">,</span> <span class="n">b1_y1</span> <span class="o">=</span> <span class="n">box</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">box</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
  214. <span class="n">b1_x2</span><span class="p">,</span> <span class="n">b1_y2</span> <span class="o">=</span> <span class="n">box</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">box</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">3</span><span class="p">]</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
  215. <span class="n">b2_x1</span> <span class="o">=</span> <span class="n">b1_x1</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
  216. <span class="n">b2_x2</span> <span class="o">=</span> <span class="n">b1_x2</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
  217. <span class="n">b2_y1</span> <span class="o">=</span> <span class="n">b1_y1</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
  218. <span class="n">b2_y2</span> <span class="o">=</span> <span class="n">b1_y2</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
  219. <span class="n">intersection_area</span> <span class="o">=</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">b1_x2</span><span class="p">,</span> <span class="n">b2_x2</span><span class="p">)</span> <span class="o">-</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">b1_x1</span><span class="p">,</span> <span class="n">b2_x1</span><span class="p">))</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">*</span> \
  220. <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">b1_y2</span><span class="p">,</span> <span class="n">b2_y2</span><span class="p">)</span> <span class="o">-</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">b1_y1</span><span class="p">,</span> <span class="n">b2_y1</span><span class="p">))</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
  221. <span class="c1"># Union Area</span>
  222. <span class="n">w1</span><span class="p">,</span> <span class="n">h1</span> <span class="o">=</span> <span class="n">b1_x2</span> <span class="o">-</span> <span class="n">b1_x1</span><span class="p">,</span> <span class="n">b1_y2</span> <span class="o">-</span> <span class="n">b1_y1</span>
  223. <span class="n">w2</span><span class="p">,</span> <span class="n">h2</span> <span class="o">=</span> <span class="n">b2_x2</span> <span class="o">-</span> <span class="n">b2_x1</span><span class="p">,</span> <span class="n">b2_y2</span> <span class="o">-</span> <span class="n">b2_y1</span>
  224. <span class="n">union_area</span> <span class="o">=</span> <span class="p">(</span><span class="n">w1</span> <span class="o">*</span> <span class="n">h1</span> <span class="o">+</span> <span class="mf">1e-16</span><span class="p">)</span> <span class="o">+</span> <span class="n">w2</span> <span class="o">*</span> <span class="n">h2</span> <span class="o">-</span> <span class="n">intersection_area</span>
  225. <span class="n">ious</span> <span class="o">=</span> <span class="n">intersection_area</span> <span class="o">/</span> <span class="n">union_area</span>
  226. <span class="k">return</span> <span class="n">ious</span></div>
  227. <div class="viewcode-block" id="change_bbox_bounds_for_image_size"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.change_bbox_bounds_for_image_size">[docs]</a><span class="k">def</span> <span class="nf">change_bbox_bounds_for_image_size</span><span class="p">(</span><span class="n">boxes</span><span class="p">,</span> <span class="n">img_shape</span><span class="p">):</span>
  228. <span class="c1"># CLIP BOUNDING XYXY BOUNDING BOXES TO IMAGE SHAPE (HEIGHT, WIDTH)</span>
  229. <span class="n">boxes</span><span class="p">[:,</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">]]</span> <span class="o">=</span> <span class="n">boxes</span><span class="p">[:,</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">]]</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">min</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="nb">max</span><span class="o">=</span><span class="n">img_shape</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
  230. <span class="n">boxes</span><span class="p">[:,</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">]]</span> <span class="o">=</span> <span class="n">boxes</span><span class="p">[:,</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">]]</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">min</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="nb">max</span><span class="o">=</span><span class="n">img_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
  231. <span class="k">return</span> <span class="n">boxes</span></div>
  232. <div class="viewcode-block" id="DetectionPostPredictionCallback"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.DetectionPostPredictionCallback">[docs]</a><span class="k">class</span> <span class="nc">DetectionPostPredictionCallback</span><span class="p">(</span><span class="n">ABC</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
  233. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
  234. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
  235. <div class="viewcode-block" id="DetectionPostPredictionCallback.forward"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.DetectionPostPredictionCallback.forward">[docs]</a> <span class="nd">@abstractmethod</span>
  236. <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">device</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
  237. <span class="sd">&quot;&quot;&quot;</span>
  238. <span class="sd"> :param x: the output of your model</span>
  239. <span class="sd"> :param device: the device to move all output tensors into</span>
  240. <span class="sd"> :return: a list with length batch_size, each item in the list is a detections</span>
  241. <span class="sd"> with shape: nx6 (x1, y1, x2, y2, confidence, class) where x and y are in range [0,1]</span>
  242. <span class="sd"> &quot;&quot;&quot;</span>
  243. <span class="k">raise</span> <span class="ne">NotImplementedError</span></div></div>
  244. <div class="viewcode-block" id="IouThreshold"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.IouThreshold">[docs]</a><span class="k">class</span> <span class="nc">IouThreshold</span><span class="p">(</span><span class="nb">tuple</span><span class="p">,</span> <span class="n">Enum</span><span class="p">):</span>
  245. <span class="n">MAP_05</span> <span class="o">=</span> <span class="p">(</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">)</span>
  246. <span class="n">MAP_05_TO_095</span> <span class="o">=</span> <span class="p">(</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.95</span><span class="p">)</span>
  247. <div class="viewcode-block" id="IouThreshold.is_range"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.IouThreshold.is_range">[docs]</a> <span class="k">def</span> <span class="nf">is_range</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  248. <span class="k">return</span> <span class="bp">self</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">!=</span> <span class="bp">self</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span></div>
  249. <div class="viewcode-block" id="IouThreshold.to_tensor"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.IouThreshold.to_tensor">[docs]</a> <span class="k">def</span> <span class="nf">to_tensor</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  250. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_range</span><span class="p">():</span>
  251. <span class="n">n_iou_thresh</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="nb">round</span><span class="p">((</span><span class="bp">self</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="bp">self</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="o">/</span> <span class="mf">0.05</span><span class="p">))</span> <span class="o">+</span> <span class="mi">1</span>
  252. <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">linspace</span><span class="p">(</span><span class="bp">self</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="bp">self</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">n_iou_thresh</span><span class="p">)</span>
  253. <span class="k">else</span><span class="p">:</span>
  254. <span class="n">n_iou_thresh</span> <span class="o">=</span> <span class="mi">1</span>
  255. <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="bp">self</span><span class="p">[</span><span class="mi">0</span><span class="p">]])</span></div></div>
  256. <div class="viewcode-block" id="box_iou"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.box_iou">[docs]</a><span class="k">def</span> <span class="nf">box_iou</span><span class="p">(</span><span class="n">box1</span><span class="p">,</span> <span class="n">box2</span><span class="p">):</span>
  257. <span class="c1"># https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py</span>
  258. <span class="sd">&quot;&quot;&quot;</span>
  259. <span class="sd"> Return intersection-over-union (Jaccard index) of boxes.</span>
  260. <span class="sd"> Both sets of boxes are expected to be in (x1, y1, x2, y2) format.</span>
  261. <span class="sd"> Arguments:</span>
  262. <span class="sd"> box1 (Tensor[N, 4])</span>
  263. <span class="sd"> box2 (Tensor[M, 4])</span>
  264. <span class="sd"> Returns:</span>
  265. <span class="sd"> iou (Tensor[N, M]): the NxM matrix containing the pairwise</span>
  266. <span class="sd"> IoU values for every element in boxes1 and boxes2</span>
  267. <span class="sd"> &quot;&quot;&quot;</span>
  268. <span class="k">def</span> <span class="nf">box_area</span><span class="p">(</span><span class="n">box</span><span class="p">):</span>
  269. <span class="c1"># box = 4xn</span>
  270. <span class="k">return</span> <span class="p">(</span><span class="n">box</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">-</span> <span class="n">box</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="o">*</span> <span class="p">(</span><span class="n">box</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span> <span class="o">-</span> <span class="n">box</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
  271. <span class="n">area1</span> <span class="o">=</span> <span class="n">box_area</span><span class="p">(</span><span class="n">box1</span><span class="o">.</span><span class="n">T</span><span class="p">)</span>
  272. <span class="n">area2</span> <span class="o">=</span> <span class="n">box_area</span><span class="p">(</span><span class="n">box2</span><span class="o">.</span><span class="n">T</span><span class="p">)</span>
  273. <span class="c1"># inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)</span>
  274. <span class="n">inter</span> <span class="o">=</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">box1</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">,</span> <span class="mi">2</span><span class="p">:],</span> <span class="n">box2</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:])</span> <span class="o">-</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">box1</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">,</span> <span class="p">:</span><span class="mi">2</span><span class="p">],</span> <span class="n">box2</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]))</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span>
  275. <span class="k">return</span> <span class="n">inter</span> <span class="o">/</span> <span class="p">(</span><span class="n">area1</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">+</span> <span class="n">area2</span> <span class="o">-</span> <span class="n">inter</span><span class="p">)</span> <span class="c1"># iou = inter / (area1 + area2 - inter)</span></div>
  276. <div class="viewcode-block" id="non_max_suppression"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.non_max_suppression">[docs]</a><span class="k">def</span> <span class="nf">non_max_suppression</span><span class="p">(</span><span class="n">prediction</span><span class="p">,</span> <span class="n">conf_thres</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">iou_thres</span><span class="o">=</span><span class="mf">0.6</span><span class="p">,</span>
  277. <span class="n">multi_label_per_box</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> <span class="n">with_confidence</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
  278. <span class="sd">&quot;&quot;&quot;</span>
  279. <span class="sd"> Performs Non-Maximum Suppression (NMS) on inference results</span>
  280. <span class="sd"> :param prediction: raw model prediction</span>
  281. <span class="sd"> :param conf_thres: below the confidence threshold - prediction are discarded</span>
  282. <span class="sd"> :param iou_thres: IoU threshold for the nms algorithm</span>
  283. <span class="sd"> :param multi_label_per_box: whether to use re-use each box with all possible labels</span>
  284. <span class="sd"> (instead of the maximum confidence all confidences above threshold</span>
  285. <span class="sd"> will be sent to NMS); by default is set to True</span>
  286. <span class="sd"> :param with_confidence: whether to multiply objectness score with class score.</span>
  287. <span class="sd"> usually valid for Yolo models only.</span>
  288. <span class="sd"> :return: (x1, y1, x2, y2, object_conf, class_conf, class)</span>
  289. <span class="sd"> Returns:</span>
  290. <span class="sd"> detections with shape: nx6 (x1, y1, x2, y2, conf, cls)</span>
  291. <span class="sd"> &quot;&quot;&quot;</span>
  292. <span class="n">candidates_above_thres</span> <span class="o">=</span> <span class="n">prediction</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">4</span><span class="p">]</span> <span class="o">&gt;</span> <span class="n">conf_thres</span> <span class="c1"># filter by confidence</span>
  293. <span class="n">output</span> <span class="o">=</span> <span class="p">[</span><span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">prediction</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
  294. <span class="k">for</span> <span class="n">image_idx</span><span class="p">,</span> <span class="n">pred</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">prediction</span><span class="p">):</span>
  295. <span class="n">pred</span> <span class="o">=</span> <span class="n">pred</span><span class="p">[</span><span class="n">candidates_above_thres</span><span class="p">[</span><span class="n">image_idx</span><span class="p">]]</span> <span class="c1"># confident</span>
  296. <span class="k">if</span> <span class="ow">not</span> <span class="n">pred</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]:</span> <span class="c1"># If none remain process next image</span>
  297. <span class="k">continue</span>
  298. <span class="k">if</span> <span class="n">with_confidence</span><span class="p">:</span>
  299. <span class="n">pred</span><span class="p">[:,</span> <span class="mi">5</span><span class="p">:]</span> <span class="o">*=</span> <span class="n">pred</span><span class="p">[:,</span> <span class="mi">4</span><span class="p">:</span><span class="mi">5</span><span class="p">]</span> <span class="c1"># multiply objectness score with class score</span>
  300. <span class="n">box</span> <span class="o">=</span> <span class="n">convert_xywh_bbox_to_xyxy</span><span class="p">(</span><span class="n">pred</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">])</span> <span class="c1"># xywh to xyxy</span>
  301. <span class="c1"># Detections matrix nx6 (xyxy, conf, cls)</span>
  302. <span class="k">if</span> <span class="n">multi_label_per_box</span><span class="p">:</span> <span class="c1"># try for all good confidence classes</span>
  303. <span class="n">i</span><span class="p">,</span> <span class="n">j</span> <span class="o">=</span> <span class="p">(</span><span class="n">pred</span><span class="p">[:,</span> <span class="mi">5</span><span class="p">:]</span> <span class="o">&gt;</span> <span class="n">conf_thres</span><span class="p">)</span><span class="o">.</span><span class="n">nonzero</span><span class="p">(</span><span class="n">as_tuple</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span><span class="o">.</span><span class="n">T</span>
  304. <span class="n">pred</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">box</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">pred</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span> <span class="o">+</span> <span class="mi">5</span><span class="p">,</span> <span class="kc">None</span><span class="p">],</span> <span class="n">j</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span><span class="o">.</span><span class="n">float</span><span class="p">()),</span> <span class="mi">1</span><span class="p">)</span>
  305. <span class="k">else</span><span class="p">:</span> <span class="c1"># best class only</span>
  306. <span class="n">conf</span><span class="p">,</span> <span class="n">j</span> <span class="o">=</span> <span class="n">pred</span><span class="p">[:,</span> <span class="mi">5</span><span class="p">:]</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  307. <span class="n">pred</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">box</span><span class="p">,</span> <span class="n">conf</span><span class="p">,</span> <span class="n">j</span><span class="o">.</span><span class="n">float</span><span class="p">()),</span> <span class="mi">1</span><span class="p">)[</span><span class="n">conf</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="o">&gt;</span> <span class="n">conf_thres</span><span class="p">]</span>
  308. <span class="k">if</span> <span class="ow">not</span> <span class="n">pred</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]:</span> <span class="c1"># If none remain process next image</span>
  309. <span class="k">continue</span>
  310. <span class="c1"># Apply torch batched NMS algorithm</span>
  311. <span class="n">boxes</span><span class="p">,</span> <span class="n">scores</span><span class="p">,</span> <span class="n">cls_idx</span> <span class="o">=</span> <span class="n">pred</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">],</span> <span class="n">pred</span><span class="p">[:,</span> <span class="mi">4</span><span class="p">],</span> <span class="n">pred</span><span class="p">[:,</span> <span class="mi">5</span><span class="p">]</span>
  312. <span class="n">idx_to_keep</span> <span class="o">=</span> <span class="n">torchvision</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">boxes</span><span class="o">.</span><span class="n">batched_nms</span><span class="p">(</span><span class="n">boxes</span><span class="p">,</span> <span class="n">scores</span><span class="p">,</span> <span class="n">cls_idx</span><span class="p">,</span> <span class="n">iou_thres</span><span class="p">)</span>
  313. <span class="n">output</span><span class="p">[</span><span class="n">image_idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">pred</span><span class="p">[</span><span class="n">idx_to_keep</span><span class="p">]</span>
  314. <span class="k">return</span> <span class="n">output</span></div>
  315. <div class="viewcode-block" id="matrix_non_max_suppression"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.matrix_non_max_suppression">[docs]</a><span class="k">def</span> <span class="nf">matrix_non_max_suppression</span><span class="p">(</span><span class="n">pred</span><span class="p">,</span> <span class="n">conf_thres</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span> <span class="n">kernel</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39;gaussian&#39;</span><span class="p">,</span>
  316. <span class="n">sigma</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">3.0</span><span class="p">,</span> <span class="n">max_num_of_detections</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">500</span><span class="p">):</span>
  317. <span class="sd">&quot;&quot;&quot;Performs Matrix Non-Maximum Suppression (NMS) on inference results</span>
  318. <span class="sd"> https://arxiv.org/pdf/1912.04488.pdf</span>
  319. <span class="sd"> :param pred: raw model prediction (in test mode) - a Tensor of shape [batch, num_predictions, 85]</span>
  320. <span class="sd"> where each item format is (x, y, w, h, object_conf, class_conf, ... 80 classes score ...)</span>
  321. <span class="sd"> :param conf_thres: below the confidence threshold - prediction are discarded</span>
  322. <span class="sd"> :param kernel: type of kernel to use [&#39;gaussian&#39;, &#39;linear&#39;]</span>
  323. <span class="sd"> :param sigma: sigma for the gussian kernel</span>
  324. <span class="sd"> :param max_num_of_detections: maximum number of boxes to output</span>
  325. <span class="sd"> :return: list of (x1, y1, x2, y2, object_conf, class_conf, class)</span>
  326. <span class="sd"> Returns:</span>
  327. <span class="sd"> detections list with shape: (x1, y1, x2, y2, conf, cls)</span>
  328. <span class="sd"> &quot;&quot;&quot;</span>
  329. <span class="c1"># MULTIPLY CONF BY CLASS CONF TO GET COMBINED CONFIDENCE</span>
  330. <span class="n">class_conf</span><span class="p">,</span> <span class="n">class_pred</span> <span class="o">=</span> <span class="n">pred</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">5</span><span class="p">:]</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span>
  331. <span class="n">pred</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">4</span><span class="p">]</span> <span class="o">*=</span> <span class="n">class_conf</span>
  332. <span class="c1"># BOX (CENTER X, CENTER Y, WIDTH, HEIGHT) TO (X1, Y1, X2, Y2)</span>
  333. <span class="n">pred</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">]</span> <span class="o">=</span> <span class="n">convert_xywh_bbox_to_xyxy</span><span class="p">(</span><span class="n">pred</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">])</span>
  334. <span class="c1"># DETECTIONS ORDERED AS (x1y1x2y2, obj_conf, class_conf, class_pred)</span>
  335. <span class="n">pred</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">pred</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">:</span><span class="mi">5</span><span class="p">],</span> <span class="n">class_pred</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">2</span><span class="p">)),</span> <span class="mi">2</span><span class="p">)</span>
  336. <span class="c1"># SORT DETECTIONS BY DECREASING CONFIDENCE SCORES</span>
  337. <span class="n">sort_ind</span> <span class="o">=</span> <span class="p">(</span><span class="o">-</span><span class="n">pred</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">4</span><span class="p">])</span><span class="o">.</span><span class="n">argsort</span><span class="p">()</span>
  338. <span class="n">pred</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">([</span><span class="n">pred</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">sort_ind</span><span class="p">[</span><span class="n">i</span><span class="p">]]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">pred</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])])[:,</span> <span class="mi">0</span><span class="p">:</span><span class="n">max_num_of_detections</span><span class="p">]</span>
  339. <span class="n">ious</span> <span class="o">=</span> <span class="n">calc_bbox_iou_matrix</span><span class="p">(</span><span class="n">pred</span><span class="p">)</span>
  340. <span class="n">ious</span> <span class="o">=</span> <span class="n">ious</span><span class="o">.</span><span class="n">triu</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
  341. <span class="c1"># CREATE A LABELS MASK, WE WANT ONLY BOXES WITH THE SAME LABEL TO AFFECT EACH OTHER</span>
  342. <span class="n">labels</span> <span class="o">=</span> <span class="n">pred</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">5</span><span class="p">:]</span>
  343. <span class="n">labeles_matrix</span> <span class="o">=</span> <span class="p">(</span><span class="n">labels</span> <span class="o">==</span> <span class="n">labels</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span><span class="o">.</span><span class="n">float</span><span class="p">()</span><span class="o">.</span><span class="n">triu</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
  344. <span class="n">ious</span> <span class="o">*=</span> <span class="n">labeles_matrix</span>
  345. <span class="n">ious_cmax</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">ious</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
  346. <span class="n">ious_cmax</span> <span class="o">=</span> <span class="n">ious_cmax</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">max_num_of_detections</span><span class="p">)</span>
  347. <span class="k">if</span> <span class="n">kernel</span> <span class="o">==</span> <span class="s1">&#39;gaussian&#39;</span><span class="p">:</span>
  348. <span class="n">decay_matrix</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span> <span class="o">*</span> <span class="n">sigma</span> <span class="o">*</span> <span class="p">(</span><span class="n">ious</span> <span class="o">**</span> <span class="mi">2</span><span class="p">))</span>
  349. <span class="n">compensate_matrix</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span> <span class="o">*</span> <span class="n">sigma</span> <span class="o">*</span> <span class="p">(</span><span class="n">ious_cmax</span> <span class="o">**</span> <span class="mi">2</span><span class="p">))</span>
  350. <span class="n">decay</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="p">(</span><span class="n">decay_matrix</span> <span class="o">/</span> <span class="n">compensate_matrix</span><span class="p">)</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
  351. <span class="k">else</span><span class="p">:</span>
  352. <span class="n">decay</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">ious</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">ious_cmax</span><span class="p">)</span>
  353. <span class="n">decay</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">decay</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
  354. <span class="n">pred</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">4</span><span class="p">]</span> <span class="o">*=</span> <span class="n">decay</span>
  355. <span class="n">output</span> <span class="o">=</span> <span class="p">[</span><span class="n">pred</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">pred</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="p">:,</span> <span class="mi">4</span><span class="p">]</span> <span class="o">&gt;</span> <span class="n">conf_thres</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">pred</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])]</span>
  356. <span class="k">return</span> <span class="n">output</span></div>
  357. <div class="viewcode-block" id="NMS_Type"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.NMS_Type">[docs]</a><span class="k">class</span> <span class="nc">NMS_Type</span><span class="p">(</span><span class="nb">str</span><span class="p">,</span> <span class="n">Enum</span><span class="p">):</span>
  358. <span class="sd">&quot;&quot;&quot;</span>
  359. <span class="sd"> Type of non max suppression algorithm that can be used for post processing detection</span>
  360. <span class="sd"> &quot;&quot;&quot;</span>
  361. <span class="n">ITERATIVE</span> <span class="o">=</span> <span class="s1">&#39;iterative&#39;</span>
  362. <span class="n">MATRIX</span> <span class="o">=</span> <span class="s1">&#39;matrix&#39;</span></div>
  363. <div class="viewcode-block" id="undo_image_preprocessing"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.undo_image_preprocessing">[docs]</a><span class="k">def</span> <span class="nf">undo_image_preprocessing</span><span class="p">(</span><span class="n">im_tensor</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">:</span>
  364. <span class="sd">&quot;&quot;&quot;</span>
  365. <span class="sd"> :param im_tensor: images in a batch after preprocessing for inference, RGB, (B, C, H, W)</span>
  366. <span class="sd"> :return: images in a batch in cv2 format, BGR, (B, H, W, C)</span>
  367. <span class="sd"> &quot;&quot;&quot;</span>
  368. <span class="n">im_np</span> <span class="o">=</span> <span class="n">im_tensor</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
  369. <span class="n">im_np</span> <span class="o">=</span> <span class="n">im_np</span><span class="p">[:,</span> <span class="p">::</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:]</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
  370. <span class="n">im_np</span> <span class="o">*=</span> <span class="mf">255.</span>
  371. <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">ascontiguousarray</span><span class="p">(</span><span class="n">im_np</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span></div>
  372. <div class="viewcode-block" id="DetectionVisualization"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.DetectionVisualization">[docs]</a><span class="k">class</span> <span class="nc">DetectionVisualization</span><span class="p">:</span>
  373. <span class="nd">@staticmethod</span>
  374. <span class="k">def</span> <span class="nf">_generate_color_mapping</span><span class="p">(</span><span class="n">num_classes</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">]]:</span>
  375. <span class="sd">&quot;&quot;&quot;</span>
  376. <span class="sd"> Generate a unique BGR color for each class</span>
  377. <span class="sd"> &quot;&quot;&quot;</span>
  378. <span class="n">cmap</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">cm</span><span class="o">.</span><span class="n">get_cmap</span><span class="p">(</span><span class="s1">&#39;gist_rainbow&#39;</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">)</span>
  379. <span class="n">colors</span> <span class="o">=</span> <span class="p">[</span><span class="n">cmap</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="nb">bytes</span><span class="o">=</span><span class="kc">True</span><span class="p">)[:</span><span class="mi">3</span><span class="p">][::</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_classes</span><span class="p">)]</span>
  380. <span class="k">return</span> <span class="p">[</span><span class="nb">tuple</span><span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">v</span><span class="p">)</span> <span class="k">for</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">c</span><span class="p">)</span> <span class="k">for</span> <span class="n">c</span> <span class="ow">in</span> <span class="n">colors</span><span class="p">]</span>
  381. <span class="nd">@staticmethod</span>
  382. <span class="k">def</span> <span class="nf">_draw_box_title</span><span class="p">(</span><span class="n">color_mapping</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span> <span class="n">class_names</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> <span class="n">box_thickness</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
  383. <span class="n">image_np</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">x1</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">y1</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">x2</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">y2</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">class_id</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
  384. <span class="n">pred_conf</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="n">is_target</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
  385. <span class="n">color</span> <span class="o">=</span> <span class="n">color_mapping</span><span class="p">[</span><span class="n">class_id</span><span class="p">]</span>
  386. <span class="n">class_name</span> <span class="o">=</span> <span class="n">class_names</span><span class="p">[</span><span class="n">class_id</span><span class="p">]</span>
  387. <span class="c1"># Draw the box</span>
  388. <span class="n">image_np</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">rectangle</span><span class="p">(</span><span class="n">image_np</span><span class="p">,</span> <span class="p">(</span><span class="n">x1</span><span class="p">,</span> <span class="n">y1</span><span class="p">),</span> <span class="p">(</span><span class="n">x2</span><span class="p">,</span> <span class="n">y2</span><span class="p">),</span> <span class="n">color</span><span class="p">,</span> <span class="n">box_thickness</span><span class="p">)</span>
  389. <span class="c1"># Caption with class name and confidence if given</span>
  390. <span class="n">text_color</span> <span class="o">=</span> <span class="p">(</span><span class="mi">255</span><span class="p">,</span> <span class="mi">255</span><span class="p">,</span> <span class="mi">255</span><span class="p">)</span> <span class="c1"># white</span>
  391. <span class="k">if</span> <span class="n">is_target</span><span class="p">:</span>
  392. <span class="n">title</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">&#39;[GT] </span><span class="si">{</span><span class="n">class_name</span><span class="si">}</span><span class="s1">&#39;</span>
  393. <span class="k">if</span> <span class="ow">not</span> <span class="n">is_target</span><span class="p">:</span>
  394. <span class="n">title</span> <span class="o">=</span> <span class="sa">f</span><span class="s1">&#39;[Pred] </span><span class="si">{</span><span class="n">class_name</span><span class="si">}</span><span class="s1"> </span><span class="si">{</span><span class="nb">str</span><span class="p">(</span><span class="nb">round</span><span class="p">(</span><span class="n">pred_conf</span><span class="p">,</span> <span class="mi">2</span><span class="p">))</span> <span class="k">if</span> <span class="n">pred_conf</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="s2">&quot;&quot;</span><span class="si">}</span><span class="s1">&#39;</span>
  395. <span class="n">image_np</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">rectangle</span><span class="p">(</span><span class="n">image_np</span><span class="p">,</span> <span class="p">(</span><span class="n">x1</span><span class="p">,</span> <span class="n">y1</span> <span class="o">-</span> <span class="mi">15</span><span class="p">),</span> <span class="p">(</span><span class="n">x1</span> <span class="o">+</span> <span class="nb">len</span><span class="p">(</span><span class="n">title</span><span class="p">)</span> <span class="o">*</span> <span class="mi">10</span><span class="p">,</span> <span class="n">y1</span><span class="p">),</span> <span class="n">color</span><span class="p">,</span> <span class="n">cv2</span><span class="o">.</span><span class="n">FILLED</span><span class="p">)</span>
  396. <span class="n">image_np</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">putText</span><span class="p">(</span><span class="n">image_np</span><span class="p">,</span> <span class="n">title</span><span class="p">,</span> <span class="p">(</span><span class="n">x1</span><span class="p">,</span> <span class="n">y1</span> <span class="o">-</span> <span class="n">box_thickness</span><span class="p">),</span> <span class="mi">2</span><span class="p">,</span> <span class="mf">.5</span><span class="p">,</span> <span class="n">text_color</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">lineType</span><span class="o">=</span><span class="n">cv2</span><span class="o">.</span><span class="n">LINE_AA</span><span class="p">)</span>
  397. <span class="k">return</span> <span class="n">image_np</span>
  398. <span class="nd">@staticmethod</span>
  399. <span class="k">def</span> <span class="nf">_visualize_image</span><span class="p">(</span><span class="n">image_np</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">pred_boxes</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">target_boxes</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span>
  400. <span class="n">class_names</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> <span class="n">box_thickness</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">gt_alpha</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">image_scale</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
  401. <span class="n">checkpoint_dir</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">image_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
  402. <span class="n">image_np</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span><span class="n">image_np</span><span class="p">,</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="n">fx</span><span class="o">=</span><span class="n">image_scale</span><span class="p">,</span> <span class="n">fy</span><span class="o">=</span><span class="n">image_scale</span><span class="p">,</span> <span class="n">interpolation</span><span class="o">=</span><span class="n">cv2</span><span class="o">.</span><span class="n">INTER_NEAREST</span><span class="p">)</span>
  403. <span class="n">color_mapping</span> <span class="o">=</span> <span class="n">DetectionVisualization</span><span class="o">.</span><span class="n">_generate_color_mapping</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">class_names</span><span class="p">))</span>
  404. <span class="c1"># Draw predictions</span>
  405. <span class="n">pred_boxes</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">4</span><span class="p">]</span> <span class="o">*=</span> <span class="n">image_scale</span>
  406. <span class="k">for</span> <span class="n">box</span> <span class="ow">in</span> <span class="n">pred_boxes</span><span class="p">:</span>
  407. <span class="n">image_np</span> <span class="o">=</span> <span class="n">DetectionVisualization</span><span class="o">.</span><span class="n">_draw_box_title</span><span class="p">(</span><span class="n">color_mapping</span><span class="p">,</span> <span class="n">class_names</span><span class="p">,</span> <span class="n">box_thickness</span><span class="p">,</span>
  408. <span class="n">image_np</span><span class="p">,</span> <span class="o">*</span><span class="n">box</span><span class="p">[:</span><span class="mi">4</span><span class="p">]</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">int</span><span class="p">),</span>
  409. <span class="n">class_id</span><span class="o">=</span><span class="nb">int</span><span class="p">(</span><span class="n">box</span><span class="p">[</span><span class="mi">5</span><span class="p">]),</span> <span class="n">pred_conf</span><span class="o">=</span><span class="n">box</span><span class="p">[</span><span class="mi">4</span><span class="p">])</span>
  410. <span class="c1"># Draw ground truths</span>
  411. <span class="n">target_boxes_image</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">image_np</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span>
  412. <span class="k">for</span> <span class="n">box</span> <span class="ow">in</span> <span class="n">target_boxes</span><span class="p">:</span>
  413. <span class="n">target_boxes_image</span> <span class="o">=</span> <span class="n">DetectionVisualization</span><span class="o">.</span><span class="n">_draw_box_title</span><span class="p">(</span><span class="n">color_mapping</span><span class="p">,</span> <span class="n">class_names</span><span class="p">,</span> <span class="n">box_thickness</span><span class="p">,</span>
  414. <span class="n">target_boxes_image</span><span class="p">,</span> <span class="o">*</span><span class="n">box</span><span class="p">[</span><span class="mi">2</span><span class="p">:],</span>
  415. <span class="n">class_id</span><span class="o">=</span><span class="n">box</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">is_target</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  416. <span class="c1"># Transparent overlay of ground truth boxes</span>
  417. <span class="n">mask</span> <span class="o">=</span> <span class="n">target_boxes_image</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">bool</span><span class="p">)</span>
  418. <span class="n">image_np</span><span class="p">[</span><span class="n">mask</span><span class="p">]</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">addWeighted</span><span class="p">(</span><span class="n">image_np</span><span class="p">,</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">gt_alpha</span><span class="p">,</span> <span class="n">target_boxes_image</span><span class="p">,</span> <span class="n">gt_alpha</span><span class="p">,</span> <span class="mi">0</span><span class="p">)[</span><span class="n">mask</span><span class="p">]</span>
  419. <span class="k">if</span> <span class="n">checkpoint_dir</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
  420. <span class="k">return</span> <span class="n">image_np</span>
  421. <span class="k">else</span><span class="p">:</span>
  422. <span class="n">pathlib</span><span class="o">.</span><span class="n">Path</span><span class="p">(</span><span class="n">checkpoint_dir</span><span class="p">)</span><span class="o">.</span><span class="n">mkdir</span><span class="p">(</span><span class="n">parents</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  423. <span class="n">cv2</span><span class="o">.</span><span class="n">imwrite</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">checkpoint_dir</span><span class="p">,</span> <span class="nb">str</span><span class="p">(</span><span class="n">image_name</span><span class="p">)</span> <span class="o">+</span> <span class="s1">&#39;.jpg&#39;</span><span class="p">),</span> <span class="n">image_np</span><span class="p">)</span>
  424. <span class="nd">@staticmethod</span>
  425. <span class="k">def</span> <span class="nf">_scaled_ccwh_to_xyxy</span><span class="p">(</span><span class="n">target_boxes</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">h</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">w</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">image_scale</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">:</span>
  426. <span class="sd">&quot;&quot;&quot;</span>
  427. <span class="sd"> Modifies target_boxes inplace</span>
  428. <span class="sd"> :param target_boxes: (c1, c2, w, h) boxes in [0, 1] range</span>
  429. <span class="sd"> :param h: image height</span>
  430. <span class="sd"> :param w: image width</span>
  431. <span class="sd"> :param image_scale: desired scale for the boxes w.r.t. w and h</span>
  432. <span class="sd"> :return: targets in (x1, y1, x2, y2) format</span>
  433. <span class="sd"> in range [0, w * self.image_scale] [0, h * self.image_scale]</span>
  434. <span class="sd"> &quot;&quot;&quot;</span>
  435. <span class="c1"># unscale</span>
  436. <span class="n">target_boxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:]</span> <span class="o">*=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="n">w</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="n">h</span><span class="p">]])</span>
  437. <span class="c1"># x1 = c1 - w // 2; y1 = c2 - h // 2</span>
  438. <span class="n">target_boxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">-=</span> <span class="n">target_boxes</span><span class="p">[:,</span> <span class="mi">4</span><span class="p">]</span> <span class="o">//</span> <span class="mi">2</span>
  439. <span class="n">target_boxes</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">-=</span> <span class="n">target_boxes</span><span class="p">[:,</span> <span class="mi">5</span><span class="p">]</span> <span class="o">//</span> <span class="mi">2</span>
  440. <span class="c1"># x2 = w + x1; y2 = h + y1</span>
  441. <span class="n">target_boxes</span><span class="p">[:,</span> <span class="mi">4</span><span class="p">]</span> <span class="o">+=</span> <span class="n">target_boxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span>
  442. <span class="n">target_boxes</span><span class="p">[:,</span> <span class="mi">5</span><span class="p">]</span> <span class="o">+=</span> <span class="n">target_boxes</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span>
  443. <span class="n">target_boxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:]</span> <span class="o">*=</span> <span class="n">image_scale</span>
  444. <span class="n">target_boxes</span> <span class="o">=</span> <span class="n">target_boxes</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">int</span><span class="p">)</span>
  445. <span class="k">return</span> <span class="n">target_boxes</span>
  446. <div class="viewcode-block" id="DetectionVisualization.visualize_batch"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.DetectionVisualization.visualize_batch">[docs]</a> <span class="nd">@staticmethod</span>
  447. <span class="k">def</span> <span class="nf">visualize_batch</span><span class="p">(</span><span class="n">image_tensor</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">pred_boxes</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">target_boxes</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
  448. <span class="n">batch_name</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">str</span><span class="p">],</span> <span class="n">class_names</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> <span class="n">checkpoint_dir</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
  449. <span class="n">undo_preprocessing_func</span><span class="p">:</span> <span class="n">Callable</span><span class="p">[[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]</span> <span class="o">=</span> <span class="n">undo_image_preprocessing</span><span class="p">,</span>
  450. <span class="n">box_thickness</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2</span><span class="p">,</span> <span class="n">image_scale</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.</span><span class="p">,</span> <span class="n">gt_alpha</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">.4</span><span class="p">):</span>
  451. <span class="sd">&quot;&quot;&quot;</span>
  452. <span class="sd"> A helper function to visualize detections predicted by a network:</span>
  453. <span class="sd"> saves images into a given path with a name that is {batch_name}_{imade_idx_in_the_batch}.jpg, one batch per call.</span>
  454. <span class="sd"> Colors are generated on the fly: uniformly sampled from color wheel to support all given classes.</span>
  455. <span class="sd"> Adjustable:</span>
  456. <span class="sd"> * Ground truth box transparency;</span>
  457. <span class="sd"> * Box width;</span>
  458. <span class="sd"> * Image size (larger or smaller than what&#39;s provided)</span>
  459. <span class="sd"> :param image_tensor: rgb images, (B, H, W, 3)</span>
  460. <span class="sd"> :param pred_boxes: boxes after NMS for each image in a batch, each (Num_boxes, 6),</span>
  461. <span class="sd"> values on dim 1 are: x1, y1, x2, y2, confidence, class</span>
  462. <span class="sd"> :param target_boxes: (Num_targets, 6), values on dim 1 are: image id in a batch, class, x y w h</span>
  463. <span class="sd"> (coordinates scaled to [0, 1])</span>
  464. <span class="sd"> :param batch_name: id of the current batch to use for image naming</span>
  465. <span class="sd"> :param class_names: names of all classes, each on its own index</span>
  466. <span class="sd"> :param checkpoint_dir: a path where images with boxes will be saved. if None, the result images will</span>
  467. <span class="sd"> be returns as a list of numpy image arrays</span>
  468. <span class="sd"> :param undo_preprocessing_func: a function to convert preprocessed images tensor into a batch of cv2-like images</span>
  469. <span class="sd"> :param box_thickness: box line thickness in px</span>
  470. <span class="sd"> :param image_scale: scale of an image w.r.t. given image size,</span>
  471. <span class="sd"> e.g. incoming images are (320x320), use scale = 2. to preview in (640x640)</span>
  472. <span class="sd"> :param gt_alpha: a value in [0., 1.] transparency on ground truth boxes,</span>
  473. <span class="sd"> 0 for invisible, 1 for fully opaque</span>
  474. <span class="sd"> &quot;&quot;&quot;</span>
  475. <span class="n">image_np</span> <span class="o">=</span> <span class="n">undo_preprocessing_func</span><span class="p">(</span><span class="n">image_tensor</span><span class="o">.</span><span class="n">detach</span><span class="p">())</span>
  476. <span class="n">targets</span> <span class="o">=</span> <span class="n">DetectionVisualization</span><span class="o">.</span><span class="n">_scaled_ccwh_to_xyxy</span><span class="p">(</span><span class="n">target_boxes</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="o">*</span><span class="n">image_np</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">:</span><span class="mi">3</span><span class="p">],</span>
  477. <span class="n">image_scale</span><span class="p">)</span>
  478. <span class="n">out_images</span> <span class="o">=</span> <span class="p">[]</span>
  479. <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">image_np</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]):</span>
  480. <span class="n">preds</span> <span class="o">=</span> <span class="n">pred_boxes</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="k">if</span> <span class="n">pred_boxes</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">np</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <span class="mi">6</span><span class="p">))</span>
  481. <span class="n">targets_cur</span> <span class="o">=</span> <span class="n">targets</span><span class="p">[</span><span class="n">targets</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">i</span><span class="p">]</span>
  482. <span class="n">image_name</span> <span class="o">=</span> <span class="s1">&#39;_&#39;</span><span class="o">.</span><span class="n">join</span><span class="p">([</span><span class="nb">str</span><span class="p">(</span><span class="n">batch_name</span><span class="p">),</span> <span class="nb">str</span><span class="p">(</span><span class="n">i</span><span class="p">)])</span>
  483. <span class="n">res_image</span> <span class="o">=</span> <span class="n">DetectionVisualization</span><span class="o">.</span><span class="n">_visualize_image</span><span class="p">(</span><span class="n">image_np</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">preds</span><span class="p">,</span> <span class="n">targets_cur</span><span class="p">,</span> <span class="n">class_names</span><span class="p">,</span> <span class="n">box_thickness</span><span class="p">,</span> <span class="n">gt_alpha</span><span class="p">,</span> <span class="n">image_scale</span><span class="p">,</span>
  484. <span class="n">checkpoint_dir</span><span class="p">,</span> <span class="n">image_name</span><span class="p">)</span>
  485. <span class="k">if</span> <span class="n">res_image</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  486. <span class="n">out_images</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">res_image</span><span class="p">)</span>
  487. <span class="k">return</span> <span class="n">out_images</span></div></div>
  488. <div class="viewcode-block" id="Anchors"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.Anchors">[docs]</a><span class="k">class</span> <span class="nc">Anchors</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
  489. <span class="sd">&quot;&quot;&quot;</span>
  490. <span class="sd"> A wrapper function to hold the anchors used by detection models such as Yolo</span>
  491. <span class="sd"> &quot;&quot;&quot;</span>
  492. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">anchors_list</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">],</span> <span class="n">strides</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]):</span>
  493. <span class="sd">&quot;&quot;&quot;</span>
  494. <span class="sd"> :param anchors_list: of the shape [[w1,h1,w2,h2,w3,h3], [w4,h4,w5,h5,w6,h6] .... where each sublist holds</span>
  495. <span class="sd"> the width and height of the anchors of a specific detection layer.</span>
  496. <span class="sd"> i.e. for a model with 3 detection layers, each containing 5 anchors the format will be a of 3 sublists of 10 numbers each</span>
  497. <span class="sd"> The width and height are in pixels (not relative to image size)</span>
  498. <span class="sd"> :param strides: a list containing the stride of the layers from which the detection heads are fed.</span>
  499. <span class="sd"> i.e. if the firs detection head is connected to the backbone after the input dimensions were reduces by 8, the first number will be 8</span>
  500. <span class="sd"> &quot;&quot;&quot;</span>
  501. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
  502. <span class="bp">self</span><span class="o">.</span><span class="n">__anchors_list</span> <span class="o">=</span> <span class="n">anchors_list</span>
  503. <span class="bp">self</span><span class="o">.</span><span class="n">__strides</span> <span class="o">=</span> <span class="n">strides</span>
  504. <span class="bp">self</span><span class="o">.</span><span class="n">_check_all_lists</span><span class="p">(</span><span class="n">anchors_list</span><span class="p">)</span>
  505. <span class="bp">self</span><span class="o">.</span><span class="n">_check_all_len_equal_and_even</span><span class="p">(</span><span class="n">anchors_list</span><span class="p">)</span>
  506. <span class="bp">self</span><span class="o">.</span><span class="n">_stride</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">strides</span><span class="p">)</span><span class="o">.</span><span class="n">float</span><span class="p">(),</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
  507. <span class="n">anchors</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">(</span><span class="n">anchors_list</span><span class="p">)</span><span class="o">.</span><span class="n">float</span><span class="p">()</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">anchors_list</span><span class="p">),</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
  508. <span class="bp">self</span><span class="o">.</span><span class="n">_anchors</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">anchors</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">_stride</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
  509. <span class="bp">self</span><span class="o">.</span><span class="n">_anchor_grid</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">anchors</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">anchors_list</span><span class="p">),</span> <span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
  510. <span class="nd">@staticmethod</span>
  511. <span class="k">def</span> <span class="nf">_check_all_lists</span><span class="p">(</span><span class="n">anchors</span><span class="p">:</span> <span class="nb">list</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
  512. <span class="k">for</span> <span class="n">a</span> <span class="ow">in</span> <span class="n">anchors</span><span class="p">:</span>
  513. <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="p">(</span><span class="nb">list</span><span class="p">,</span> <span class="n">ListConfig</span><span class="p">)):</span>
  514. <span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="s1">&#39;All objects of anchors_list must be lists&#39;</span><span class="p">)</span>
  515. <span class="nd">@staticmethod</span>
  516. <span class="k">def</span> <span class="nf">_check_all_len_equal_and_even</span><span class="p">(</span><span class="n">anchors</span><span class="p">:</span> <span class="nb">list</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
  517. <span class="n">len_of_first</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">anchors</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
  518. <span class="k">for</span> <span class="n">a</span> <span class="ow">in</span> <span class="n">anchors</span><span class="p">:</span>
  519. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">a</span><span class="p">)</span> <span class="o">%</span> <span class="mi">2</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">or</span> <span class="nb">len</span><span class="p">(</span><span class="n">a</span><span class="p">)</span> <span class="o">!=</span> <span class="n">len_of_first</span><span class="p">:</span>
  520. <span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="s1">&#39;All objects of anchors_list must be of the same even length&#39;</span><span class="p">)</span>
  521. <span class="nd">@property</span>
  522. <span class="k">def</span> <span class="nf">stride</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">:</span>
  523. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_stride</span>
  524. <span class="nd">@property</span>
  525. <span class="k">def</span> <span class="nf">anchors</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">:</span>
  526. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_anchors</span>
  527. <span class="nd">@property</span>
  528. <span class="k">def</span> <span class="nf">anchor_grid</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">:</span>
  529. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_anchor_grid</span>
  530. <span class="nd">@property</span>
  531. <span class="k">def</span> <span class="nf">detection_layers_num</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
  532. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_anchors</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
  533. <span class="nd">@property</span>
  534. <span class="k">def</span> <span class="nf">num_anchors</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
  535. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_anchors</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
  536. <span class="k">def</span> <span class="fm">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  537. <span class="k">return</span> <span class="sa">f</span><span class="s2">&quot;anchors_list: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">__anchors_list</span><span class="si">}</span><span class="s2"> strides: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">__strides</span><span class="si">}</span><span class="s2">&quot;</span></div>
  538. <div class="viewcode-block" id="xyxy2cxcywh"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.xyxy2cxcywh">[docs]</a><span class="k">def</span> <span class="nf">xyxy2cxcywh</span><span class="p">(</span><span class="n">bboxes</span><span class="p">):</span>
  539. <span class="sd">&quot;&quot;&quot;</span>
  540. <span class="sd"> Transforms bboxes from xyxy format to centerized xy wh format</span>
  541. <span class="sd"> :param bboxes: array, shaped (nboxes, 4)</span>
  542. <span class="sd"> :return: modified bboxes</span>
  543. <span class="sd"> &quot;&quot;&quot;</span>
  544. <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">-</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span>
  545. <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">=</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">-</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span>
  546. <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">*</span> <span class="mf">0.5</span>
  547. <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">*</span> <span class="mf">0.5</span>
  548. <span class="k">return</span> <span class="n">bboxes</span></div>
  549. <div class="viewcode-block" id="cxcywh2xyxy"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.cxcywh2xyxy">[docs]</a><span class="k">def</span> <span class="nf">cxcywh2xyxy</span><span class="p">(</span><span class="n">bboxes</span><span class="p">):</span>
  550. <span class="sd">&quot;&quot;&quot;</span>
  551. <span class="sd"> Transforms bboxes from centerized xy wh format to xyxy format</span>
  552. <span class="sd"> :param bboxes: array, shaped (nboxes, 4)</span>
  553. <span class="sd"> :return: modified bboxes</span>
  554. <span class="sd"> &quot;&quot;&quot;</span>
  555. <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">*</span> <span class="mf">0.5</span>
  556. <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">-</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">*</span> <span class="mf">0.5</span>
  557. <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">=</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">+</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span>
  558. <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">+</span> <span class="n">bboxes</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span>
  559. <span class="k">return</span> <span class="n">bboxes</span></div>
  560. <div class="viewcode-block" id="get_mosaic_coordinate"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.get_mosaic_coordinate">[docs]</a><span class="k">def</span> <span class="nf">get_mosaic_coordinate</span><span class="p">(</span><span class="n">mosaic_index</span><span class="p">,</span> <span class="n">xc</span><span class="p">,</span> <span class="n">yc</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">input_h</span><span class="p">,</span> <span class="n">input_w</span><span class="p">):</span>
  561. <span class="sd">&quot;&quot;&quot;</span>
  562. <span class="sd"> Returns the mosaic coordinates of final mosaic image according to mosaic image index.</span>
  563. <span class="sd"> :param mosaic_index: (int) mosaic image index</span>
  564. <span class="sd"> :param xc: (int) center x coordinate of the entire mosaic grid.</span>
  565. <span class="sd"> :param yc: (int) center y coordinate of the entire mosaic grid.</span>
  566. <span class="sd"> :param w: (int) width of bbox</span>
  567. <span class="sd"> :param h: (int) height of bbox</span>
  568. <span class="sd"> :param input_h: (int) image input height (should be 1/2 of the final mosaic output image height).</span>
  569. <span class="sd"> :param input_w: (int) image input width (should be 1/2 of the final mosaic output image width).</span>
  570. <span class="sd"> :return: (x1, y1, x2, y2), (x1s, y1s, x2s, y2s) where (x1, y1, x2, y2) are the coordinates in the final mosaic</span>
  571. <span class="sd"> output image, and (x1s, y1s, x2s, y2s) are the coordinates in the placed image.</span>
  572. <span class="sd"> &quot;&quot;&quot;</span>
  573. <span class="c1"># index0 to top left part of image</span>
  574. <span class="k">if</span> <span class="n">mosaic_index</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
  575. <span class="n">x1</span><span class="p">,</span> <span class="n">y1</span><span class="p">,</span> <span class="n">x2</span><span class="p">,</span> <span class="n">y2</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">xc</span> <span class="o">-</span> <span class="n">w</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="nb">max</span><span class="p">(</span><span class="n">yc</span> <span class="o">-</span> <span class="n">h</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="n">xc</span><span class="p">,</span> <span class="n">yc</span>
  576. <span class="n">small_coord</span> <span class="o">=</span> <span class="n">w</span> <span class="o">-</span> <span class="p">(</span><span class="n">x2</span> <span class="o">-</span> <span class="n">x1</span><span class="p">),</span> <span class="n">h</span> <span class="o">-</span> <span class="p">(</span><span class="n">y2</span> <span class="o">-</span> <span class="n">y1</span><span class="p">),</span> <span class="n">w</span><span class="p">,</span> <span class="n">h</span>
  577. <span class="c1"># index1 to top right part of image</span>
  578. <span class="k">elif</span> <span class="n">mosaic_index</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
  579. <span class="n">x1</span><span class="p">,</span> <span class="n">y1</span><span class="p">,</span> <span class="n">x2</span><span class="p">,</span> <span class="n">y2</span> <span class="o">=</span> <span class="n">xc</span><span class="p">,</span> <span class="nb">max</span><span class="p">(</span><span class="n">yc</span> <span class="o">-</span> <span class="n">h</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="nb">min</span><span class="p">(</span><span class="n">xc</span> <span class="o">+</span> <span class="n">w</span><span class="p">,</span> <span class="n">input_w</span> <span class="o">*</span> <span class="mi">2</span><span class="p">),</span> <span class="n">yc</span>
  580. <span class="n">small_coord</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="n">h</span> <span class="o">-</span> <span class="p">(</span><span class="n">y2</span> <span class="o">-</span> <span class="n">y1</span><span class="p">),</span> <span class="nb">min</span><span class="p">(</span><span class="n">w</span><span class="p">,</span> <span class="n">x2</span> <span class="o">-</span> <span class="n">x1</span><span class="p">),</span> <span class="n">h</span>
  581. <span class="c1"># index2 to bottom left part of image</span>
  582. <span class="k">elif</span> <span class="n">mosaic_index</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span>
  583. <span class="n">x1</span><span class="p">,</span> <span class="n">y1</span><span class="p">,</span> <span class="n">x2</span><span class="p">,</span> <span class="n">y2</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">xc</span> <span class="o">-</span> <span class="n">w</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="n">yc</span><span class="p">,</span> <span class="n">xc</span><span class="p">,</span> <span class="nb">min</span><span class="p">(</span><span class="n">input_h</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> <span class="n">yc</span> <span class="o">+</span> <span class="n">h</span><span class="p">)</span>
  584. <span class="n">small_coord</span> <span class="o">=</span> <span class="n">w</span> <span class="o">-</span> <span class="p">(</span><span class="n">x2</span> <span class="o">-</span> <span class="n">x1</span><span class="p">),</span> <span class="mi">0</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="nb">min</span><span class="p">(</span><span class="n">y2</span> <span class="o">-</span> <span class="n">y1</span><span class="p">,</span> <span class="n">h</span><span class="p">)</span>
  585. <span class="c1"># index2 to bottom right part of image</span>
  586. <span class="k">elif</span> <span class="n">mosaic_index</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
  587. <span class="n">x1</span><span class="p">,</span> <span class="n">y1</span><span class="p">,</span> <span class="n">x2</span><span class="p">,</span> <span class="n">y2</span> <span class="o">=</span> <span class="n">xc</span><span class="p">,</span> <span class="n">yc</span><span class="p">,</span> <span class="nb">min</span><span class="p">(</span><span class="n">xc</span> <span class="o">+</span> <span class="n">w</span><span class="p">,</span> <span class="n">input_w</span> <span class="o">*</span> <span class="mi">2</span><span class="p">),</span> <span class="nb">min</span><span class="p">(</span><span class="n">input_h</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> <span class="n">yc</span> <span class="o">+</span> <span class="n">h</span><span class="p">)</span> <span class="c1"># noqa</span>
  588. <span class="n">small_coord</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="nb">min</span><span class="p">(</span><span class="n">w</span><span class="p">,</span> <span class="n">x2</span> <span class="o">-</span> <span class="n">x1</span><span class="p">),</span> <span class="nb">min</span><span class="p">(</span><span class="n">y2</span> <span class="o">-</span> <span class="n">y1</span><span class="p">,</span> <span class="n">h</span><span class="p">)</span>
  589. <span class="k">return</span> <span class="p">(</span><span class="n">x1</span><span class="p">,</span> <span class="n">y1</span><span class="p">,</span> <span class="n">x2</span><span class="p">,</span> <span class="n">y2</span><span class="p">),</span> <span class="n">small_coord</span></div>
  590. <div class="viewcode-block" id="adjust_box_anns"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.adjust_box_anns">[docs]</a><span class="k">def</span> <span class="nf">adjust_box_anns</span><span class="p">(</span><span class="n">bbox</span><span class="p">,</span> <span class="n">scale_ratio</span><span class="p">,</span> <span class="n">padw</span><span class="p">,</span> <span class="n">padh</span><span class="p">,</span> <span class="n">w_max</span><span class="p">,</span> <span class="n">h_max</span><span class="p">):</span>
  591. <span class="sd">&quot;&quot;&quot;</span>
  592. <span class="sd"> Adjusts the bbox annotations of rescaled, padded image.</span>
  593. <span class="sd"> :param bbox: (np.array) bbox to modify.</span>
  594. <span class="sd"> :param scale_ratio: (float) scale ratio between rescale output image and original one.</span>
  595. <span class="sd"> :param padw: (int) width padding size.</span>
  596. <span class="sd"> :param padh: (int) height padding size.</span>
  597. <span class="sd"> :param w_max: (int) width border.</span>
  598. <span class="sd"> :param h_max: (int) height border</span>
  599. <span class="sd"> :return: modified bbox (np.array)</span>
  600. <span class="sd"> &quot;&quot;&quot;</span>
  601. <span class="n">bbox</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">bbox</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">*</span> <span class="n">scale_ratio</span> <span class="o">+</span> <span class="n">padw</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">w_max</span><span class="p">)</span>
  602. <span class="n">bbox</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">bbox</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">*</span> <span class="n">scale_ratio</span> <span class="o">+</span> <span class="n">padh</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">h_max</span><span class="p">)</span>
  603. <span class="k">return</span> <span class="n">bbox</span></div>
  604. <div class="viewcode-block" id="DetectionCollateFN"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.DetectionCollateFN">[docs]</a><span class="k">class</span> <span class="nc">DetectionCollateFN</span><span class="p">:</span>
  605. <span class="sd">&quot;&quot;&quot;</span>
  606. <span class="sd"> Collate function for Yolox training</span>
  607. <span class="sd"> &quot;&quot;&quot;</span>
  608. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]:</span>
  609. <span class="n">batch</span> <span class="o">=</span> <span class="n">default_collate</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
  610. <span class="n">ims</span><span class="p">,</span> <span class="n">targets</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="mi">0</span><span class="p">:</span><span class="mi">2</span><span class="p">]</span>
  611. <span class="k">return</span> <span class="n">ims</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_format_targets</span><span class="p">(</span><span class="n">targets</span><span class="p">)</span>
  612. <span class="k">def</span> <span class="nf">_format_targets</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">targets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span>
  613. <span class="n">nlabel</span> <span class="o">=</span> <span class="p">(</span><span class="n">targets</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># number of label per image</span>
  614. <span class="n">targets_merged</span> <span class="o">=</span> <span class="p">[]</span>
  615. <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">targets</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]):</span>
  616. <span class="n">targets_im</span> <span class="o">=</span> <span class="n">targets</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="p">:</span><span class="n">nlabel</span><span class="p">[</span><span class="n">i</span><span class="p">]]</span>
  617. <span class="n">batch_column</span> <span class="o">=</span> <span class="n">targets</span><span class="o">.</span><span class="n">new_ones</span><span class="p">((</span><span class="n">targets_im</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="mi">1</span><span class="p">))</span> <span class="o">*</span> <span class="n">i</span>
  618. <span class="n">targets_merged</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">batch_column</span><span class="p">,</span> <span class="n">targets_im</span><span class="p">),</span> <span class="mi">1</span><span class="p">))</span>
  619. <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">targets_merged</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span></div>
  620. <div class="viewcode-block" id="CrowdDetectionCollateFN"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.CrowdDetectionCollateFN">[docs]</a><span class="k">class</span> <span class="nc">CrowdDetectionCollateFN</span><span class="p">(</span><span class="n">DetectionCollateFN</span><span class="p">):</span>
  621. <span class="sd">&quot;&quot;&quot;</span>
  622. <span class="sd"> Collate function for Yolox training with additional_batch_items that includes crowd targets</span>
  623. <span class="sd"> &quot;&quot;&quot;</span>
  624. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]]:</span>
  625. <span class="n">batch</span> <span class="o">=</span> <span class="n">default_collate</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
  626. <span class="n">ims</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">crowd_targets</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="mi">0</span><span class="p">:</span><span class="mi">3</span><span class="p">]</span>
  627. <span class="k">return</span> <span class="n">ims</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">_format_targets</span><span class="p">(</span><span class="n">targets</span><span class="p">),</span> <span class="p">{</span><span class="s2">&quot;crowd_targets&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_format_targets</span><span class="p">(</span><span class="n">crowd_targets</span><span class="p">)}</span></div>
  628. <div class="viewcode-block" id="compute_box_area"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.compute_box_area">[docs]</a><span class="k">def</span> <span class="nf">compute_box_area</span><span class="p">(</span><span class="n">box</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span>
  629. <span class="sd">&quot;&quot;&quot;Compute the area of one or many boxes.</span>
  630. <span class="sd"> :param box: One or many boxes, shape = (4, ?), each box in format (x1, y1, x2, y2)</span>
  631. <span class="sd"> Returns:</span>
  632. <span class="sd"> Area of every box, shape = (1, ?)</span>
  633. <span class="sd"> &quot;&quot;&quot;</span>
  634. <span class="c1"># box = 4xn</span>
  635. <span class="k">return</span> <span class="p">(</span><span class="n">box</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">-</span> <span class="n">box</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="o">*</span> <span class="p">(</span><span class="n">box</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span> <span class="o">-</span> <span class="n">box</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span></div>
  636. <div class="viewcode-block" id="crowd_ioa"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.crowd_ioa">[docs]</a><span class="k">def</span> <span class="nf">crowd_ioa</span><span class="p">(</span><span class="n">det_box</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">crowd_box</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span>
  637. <span class="sd">&quot;&quot;&quot;</span>
  638. <span class="sd"> Return intersection-over-detection_area of boxes, used for crowd ground truths.</span>
  639. <span class="sd"> Both sets of boxes are expected to be in (x1, y1, x2, y2) format.</span>
  640. <span class="sd"> Arguments:</span>
  641. <span class="sd"> det_box (Tensor[N, 4])</span>
  642. <span class="sd"> crowd_box (Tensor[M, 4])</span>
  643. <span class="sd"> Returns:</span>
  644. <span class="sd"> crowd_ioa (Tensor[N, M]): the NxM matrix containing the pairwise</span>
  645. <span class="sd"> IoA values for every element in det_box and crowd_box</span>
  646. <span class="sd"> &quot;&quot;&quot;</span>
  647. <span class="n">det_area</span> <span class="o">=</span> <span class="n">compute_box_area</span><span class="p">(</span><span class="n">det_box</span><span class="o">.</span><span class="n">T</span><span class="p">)</span>
  648. <span class="c1"># inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)</span>
  649. <span class="n">inter</span> <span class="o">=</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">det_box</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">,</span> <span class="mi">2</span><span class="p">:],</span> <span class="n">crowd_box</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:])</span> <span class="o">-</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">det_box</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">,</span> <span class="p">:</span><span class="mi">2</span><span class="p">],</span> <span class="n">crowd_box</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]))</span> \
  650. <span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span>
  651. <span class="k">return</span> <span class="n">inter</span> <span class="o">/</span> <span class="n">det_area</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="c1"># crowd_ioa = inter / det_area</span></div>
  652. <div class="viewcode-block" id="compute_detection_matching"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.compute_detection_matching">[docs]</a><span class="k">def</span> <span class="nf">compute_detection_matching</span><span class="p">(</span>
  653. <span class="n">output</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
  654. <span class="n">targets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
  655. <span class="n">height</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
  656. <span class="n">width</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
  657. <span class="n">iou_thresholds</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
  658. <span class="n">denormalize_targets</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span>
  659. <span class="n">device</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
  660. <span class="n">crowd_targets</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
  661. <span class="n">top_k</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">100</span><span class="p">,</span>
  662. <span class="n">return_on_cpu</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
  663. <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">]:</span>
  664. <span class="sd">&quot;&quot;&quot;</span>
  665. <span class="sd"> Match predictions (NMS output) and the targets (ground truth) with respect to IoU and confidence score.</span>
  666. <span class="sd"> :param output: list (of length batch_size) of Tensors of shape (num_predictions, 6)</span>
  667. <span class="sd"> format: (x1, y1, x2, y2, confidence, class_label) where x1,y1,x2,y2 are according to image size</span>
  668. <span class="sd"> :param targets: targets for all images of shape (total_num_targets, 6)</span>
  669. <span class="sd"> format: (index, x, y, w, h, label) where x,y,w,h are in range [0,1]</span>
  670. <span class="sd"> :param height: dimensions of the image</span>
  671. <span class="sd"> :param width: dimensions of the image</span>
  672. <span class="sd"> :param iou_thresholds: Threshold to compute the mAP</span>
  673. <span class="sd"> :param device: Device</span>
  674. <span class="sd"> :param crowd_targets: crowd targets for all images of shape (total_num_crowd_targets, 6)</span>
  675. <span class="sd"> format: (index, x, y, w, h, label) where x,y,w,h are in range [0,1]</span>
  676. <span class="sd"> :param top_k: Number of predictions to keep per class, ordered by confidence score</span>
  677. <span class="sd"> :param denormalize_targets: If True, denormalize the targets and crowd_targets</span>
  678. <span class="sd"> :param return_on_cpu: If True, the output will be returned on &quot;CPU&quot;, otherwise it will be returned on &quot;device&quot;</span>
  679. <span class="sd"> :return: list of the following tensors, for every image:</span>
  680. <span class="sd"> :preds_matched: Tensor of shape (num_img_predictions, n_iou_thresholds)</span>
  681. <span class="sd"> True when prediction (i) is matched with a target with respect to the (j)th IoU threshold</span>
  682. <span class="sd"> :preds_to_ignore: Tensor of shape (num_img_predictions, n_iou_thresholds)</span>
  683. <span class="sd"> True when prediction (i) is matched with a crowd target with respect to the (j)th IoU threshold</span>
  684. <span class="sd"> :preds_scores: Tensor of shape (num_img_predictions), confidence score for every prediction</span>
  685. <span class="sd"> :preds_cls: Tensor of shape (num_img_predictions), predicted class for every prediction</span>
  686. <span class="sd"> :targets_cls: Tensor of shape (num_img_targets), ground truth class for every target</span>
  687. <span class="sd"> &quot;&quot;&quot;</span>
  688. <span class="n">output</span> <span class="o">=</span> <span class="nb">map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">tensor</span><span class="p">:</span> <span class="kc">None</span> <span class="k">if</span> <span class="n">tensor</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">tensor</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> <span class="n">output</span><span class="p">)</span>
  689. <span class="n">targets</span><span class="p">,</span> <span class="n">iou_thresholds</span> <span class="o">=</span> <span class="n">targets</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> <span class="n">iou_thresholds</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
  690. <span class="c1"># If crowd_targets is not provided, we patch it with an empty tensor</span>
  691. <span class="n">crowd_targets</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">6</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span> <span class="k">if</span> <span class="n">crowd_targets</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">crowd_targets</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
  692. <span class="n">batch_metrics</span> <span class="o">=</span> <span class="p">[]</span>
  693. <span class="k">for</span> <span class="n">img_i</span><span class="p">,</span> <span class="n">img_preds</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">output</span><span class="p">):</span>
  694. <span class="c1"># If img_preds is None (not prediction for this image), we patch it with an empty tensor</span>
  695. <span class="n">img_preds</span> <span class="o">=</span> <span class="n">img_preds</span> <span class="k">if</span> <span class="n">img_preds</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">6</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
  696. <span class="n">img_targets</span> <span class="o">=</span> <span class="n">targets</span><span class="p">[</span><span class="n">targets</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">img_i</span><span class="p">,</span> <span class="mi">1</span><span class="p">:]</span>
  697. <span class="n">img_crowd_targets</span> <span class="o">=</span> <span class="n">crowd_targets</span><span class="p">[</span><span class="n">crowd_targets</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">img_i</span><span class="p">,</span> <span class="mi">1</span><span class="p">:]</span>
  698. <span class="n">img_matching_tensors</span> <span class="o">=</span> <span class="n">compute_img_detection_matching</span><span class="p">(</span>
  699. <span class="n">preds</span><span class="o">=</span><span class="n">img_preds</span><span class="p">,</span>
  700. <span class="n">targets</span><span class="o">=</span><span class="n">img_targets</span><span class="p">,</span>
  701. <span class="n">crowd_targets</span><span class="o">=</span><span class="n">img_crowd_targets</span><span class="p">,</span>
  702. <span class="n">denormalize_targets</span><span class="o">=</span><span class="n">denormalize_targets</span><span class="p">,</span>
  703. <span class="n">height</span><span class="o">=</span><span class="n">height</span><span class="p">,</span>
  704. <span class="n">width</span><span class="o">=</span><span class="n">width</span><span class="p">,</span>
  705. <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span>
  706. <span class="n">iou_thresholds</span><span class="o">=</span><span class="n">iou_thresholds</span><span class="p">,</span>
  707. <span class="n">top_k</span><span class="o">=</span><span class="n">top_k</span><span class="p">,</span>
  708. <span class="n">return_on_cpu</span><span class="o">=</span><span class="n">return_on_cpu</span>
  709. <span class="p">)</span>
  710. <span class="n">batch_metrics</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">img_matching_tensors</span><span class="p">)</span>
  711. <span class="k">return</span> <span class="n">batch_metrics</span></div>
  712. <div class="viewcode-block" id="compute_img_detection_matching"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.compute_img_detection_matching">[docs]</a><span class="k">def</span> <span class="nf">compute_img_detection_matching</span><span class="p">(</span>
  713. <span class="n">preds</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
  714. <span class="n">targets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
  715. <span class="n">crowd_targets</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
  716. <span class="n">height</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
  717. <span class="n">width</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
  718. <span class="n">iou_thresholds</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
  719. <span class="n">device</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
  720. <span class="n">denormalize_targets</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span>
  721. <span class="n">top_k</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">100</span><span class="p">,</span>
  722. <span class="n">return_on_cpu</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span>
  723. <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">:</span>
  724. <span class="sd">&quot;&quot;&quot;</span>
  725. <span class="sd"> Match predictions (NMS output) and the targets (ground truth) with respect to IoU and confidence score</span>
  726. <span class="sd"> for a given image.</span>
  727. <span class="sd"> :param preds: Tensor of shape (num_img_predictions, 6)</span>
  728. <span class="sd"> format: (x1, y1, x2, y2, confidence, class_label) where x1,y1,x2,y2 are according to image size</span>
  729. <span class="sd"> :param targets: targets for this image of shape (num_img_targets, 6)</span>
  730. <span class="sd"> format: (index, x, y, w, h, label) where x,y,w,h are in range [0,1]</span>
  731. <span class="sd"> :param height: dimensions of the image</span>
  732. <span class="sd"> :param width: dimensions of the image</span>
  733. <span class="sd"> :param iou_thresholds: Threshold to compute the mAP</span>
  734. <span class="sd"> :param device:</span>
  735. <span class="sd"> :param crowd_targets: crowd targets for all images of shape (total_num_crowd_targets, 6)</span>
  736. <span class="sd"> format: (index, x, y, w, h, label) where x,y,w,h are in range [0,1]</span>
  737. <span class="sd"> :param top_k: Number of predictions to keep per class, ordered by confidence score</span>
  738. <span class="sd"> :param device: Device</span>
  739. <span class="sd"> :param denormalize_targets: If True, denormalize the targets and crowd_targets</span>
  740. <span class="sd"> :param return_on_cpu: If True, the output will be returned on &quot;CPU&quot;, otherwise it will be returned on &quot;device&quot;</span>
  741. <span class="sd"> :return:</span>
  742. <span class="sd"> :preds_matched: Tensor of shape (num_img_predictions, n_iou_thresholds)</span>
  743. <span class="sd"> True when prediction (i) is matched with a target with respect to the (j)th IoU threshold</span>
  744. <span class="sd"> :preds_to_ignore: Tensor of shape (num_img_predictions, n_iou_thresholds)</span>
  745. <span class="sd"> True when prediction (i) is matched with a crowd target with respect to the (j)th IoU threshold</span>
  746. <span class="sd"> :preds_scores: Tensor of shape (num_img_predictions), confidence score for every prediction</span>
  747. <span class="sd"> :preds_cls: Tensor of shape (num_img_predictions), predicted class for every prediction</span>
  748. <span class="sd"> :targets_cls: Tensor of shape (num_img_targets), ground truth class for every target</span>
  749. <span class="sd"> &quot;&quot;&quot;</span>
  750. <span class="n">num_iou_thresholds</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">iou_thresholds</span><span class="p">)</span>
  751. <span class="k">if</span> <span class="n">preds</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="nb">len</span><span class="p">(</span><span class="n">preds</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
  752. <span class="k">if</span> <span class="n">return_on_cpu</span><span class="p">:</span>
  753. <span class="n">device</span> <span class="o">=</span> <span class="s2">&quot;cpu&quot;</span>
  754. <span class="n">preds_matched</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <span class="n">num_iou_thresholds</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
  755. <span class="n">preds_to_ignore</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <span class="n">num_iou_thresholds</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
  756. <span class="n">preds_scores</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
  757. <span class="n">preds_cls</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
  758. <span class="n">targets_cls</span> <span class="o">=</span> <span class="n">targets</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
  759. <span class="k">return</span> <span class="n">preds_matched</span><span class="p">,</span> <span class="n">preds_to_ignore</span><span class="p">,</span> <span class="n">preds_scores</span><span class="p">,</span> <span class="n">preds_cls</span><span class="p">,</span> <span class="n">targets_cls</span>
  760. <span class="n">preds_matched</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">preds</span><span class="p">),</span> <span class="n">num_iou_thresholds</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
  761. <span class="n">targets_matched</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">targets</span><span class="p">),</span> <span class="n">num_iou_thresholds</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
  762. <span class="n">preds_to_ignore</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">preds</span><span class="p">),</span> <span class="n">num_iou_thresholds</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">bool</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
  763. <span class="n">preds_cls</span><span class="p">,</span> <span class="n">preds_box</span><span class="p">,</span> <span class="n">preds_scores</span> <span class="o">=</span> <span class="n">preds</span><span class="p">[:,</span> <span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">preds</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">:</span><span class="mi">4</span><span class="p">],</span> <span class="n">preds</span><span class="p">[:,</span> <span class="mi">4</span><span class="p">]</span>
  764. <span class="n">targets_cls</span><span class="p">,</span> <span class="n">targets_box</span> <span class="o">=</span> <span class="n">targets</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">targets</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">:</span><span class="mi">5</span><span class="p">]</span>
  765. <span class="n">crowd_targets_cls</span><span class="p">,</span> <span class="n">crowd_target_box</span> <span class="o">=</span> <span class="n">crowd_targets</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">crowd_targets</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">:</span><span class="mi">5</span><span class="p">]</span>
  766. <span class="c1"># Ignore all but the predictions that were top_k for their class</span>
  767. <span class="n">preds_idx_to_use</span> <span class="o">=</span> <span class="n">get_top_k_idx_per_cls</span><span class="p">(</span><span class="n">preds_scores</span><span class="p">,</span> <span class="n">preds_cls</span><span class="p">,</span> <span class="n">top_k</span><span class="p">)</span>
  768. <span class="n">preds_to_ignore</span><span class="p">[:,</span> <span class="p">:]</span> <span class="o">=</span> <span class="kc">True</span>
  769. <span class="n">preds_to_ignore</span><span class="p">[</span><span class="n">preds_idx_to_use</span><span class="p">]</span> <span class="o">=</span> <span class="kc">False</span>
  770. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">targets</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="ow">or</span> <span class="nb">len</span><span class="p">(</span><span class="n">crowd_targets</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
  771. <span class="c1"># CHANGE bboxes TO FIT THE IMAGE SIZE</span>
  772. <span class="n">change_bbox_bounds_for_image_size</span><span class="p">(</span><span class="n">preds</span><span class="p">,</span> <span class="p">(</span><span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">))</span>
  773. <span class="c1"># if target_format == &quot;xywh&quot;:</span>
  774. <span class="n">targets_box</span> <span class="o">=</span> <span class="n">convert_xywh_bbox_to_xyxy</span><span class="p">(</span><span class="n">targets_box</span><span class="p">)</span> <span class="c1"># cxcywh2xyxy</span>
  775. <span class="n">crowd_target_box</span> <span class="o">=</span> <span class="n">convert_xywh_bbox_to_xyxy</span><span class="p">(</span><span class="n">crowd_target_box</span><span class="p">)</span> <span class="c1"># convert_xywh_bbox_to_xyxy</span>
  776. <span class="k">if</span> <span class="n">denormalize_targets</span><span class="p">:</span>
  777. <span class="n">targets_box</span><span class="p">[:,</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">]]</span> <span class="o">*=</span> <span class="n">width</span>
  778. <span class="n">targets_box</span><span class="p">[:,</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">]]</span> <span class="o">*=</span> <span class="n">height</span>
  779. <span class="n">crowd_target_box</span><span class="p">[:,</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">]]</span> <span class="o">*=</span> <span class="n">width</span>
  780. <span class="n">crowd_target_box</span><span class="p">[:,</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">]]</span> <span class="o">*=</span> <span class="n">height</span>
  781. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">targets</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
  782. <span class="c1"># shape = (n_preds x n_targets)</span>
  783. <span class="n">iou</span> <span class="o">=</span> <span class="n">box_iou</span><span class="p">(</span><span class="n">preds_box</span><span class="p">[</span><span class="n">preds_idx_to_use</span><span class="p">],</span> <span class="n">targets_box</span><span class="p">)</span>
  784. <span class="c1"># Fill IoU values at index (i, j) with 0 when the prediction (i) and target(j) are of different class</span>
  785. <span class="c1"># Filling with 0 is equivalent to ignore these values since with want IoU &gt; iou_threshold &gt; 0</span>
  786. <span class="n">cls_mismatch</span> <span class="o">=</span> <span class="p">(</span><span class="n">preds_cls</span><span class="p">[</span><span class="n">preds_idx_to_use</span><span class="p">]</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">!=</span> <span class="n">targets_cls</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span>
  787. <span class="n">iou</span><span class="p">[</span><span class="n">cls_mismatch</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span>
  788. <span class="c1"># The matching priority is first detection confidence and then IoU value.</span>
  789. <span class="c1"># The detection is already sorted by confidence in NMS, so here for each prediction we order the targets by iou.</span>
  790. <span class="n">sorted_iou</span><span class="p">,</span> <span class="n">target_sorted</span> <span class="o">=</span> <span class="n">iou</span><span class="o">.</span><span class="n">sort</span><span class="p">(</span><span class="n">descending</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">stable</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  791. <span class="c1"># Only iterate over IoU values higher than min threshold to speed up the process</span>
  792. <span class="k">for</span> <span class="n">pred_selected_i</span><span class="p">,</span> <span class="n">target_sorted_i</span> <span class="ow">in</span> <span class="p">(</span><span class="n">sorted_iou</span> <span class="o">&gt;</span> <span class="n">iou_thresholds</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span><span class="o">.</span><span class="n">nonzero</span><span class="p">(</span><span class="n">as_tuple</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
  793. <span class="c1"># pred_selected_i and target_sorted_i are relative to filters/sorting, so we extract their absolute indexes</span>
  794. <span class="n">pred_i</span> <span class="o">=</span> <span class="n">preds_idx_to_use</span><span class="p">[</span><span class="n">pred_selected_i</span><span class="p">]</span>
  795. <span class="n">target_i</span> <span class="o">=</span> <span class="n">target_sorted</span><span class="p">[</span><span class="n">pred_selected_i</span><span class="p">,</span> <span class="n">target_sorted_i</span><span class="p">]</span>
  796. <span class="c1"># Vector[j], True when IoU(pred_i, target_i) is above the (j)th threshold</span>
  797. <span class="n">is_iou_above_threshold</span> <span class="o">=</span> <span class="n">sorted_iou</span><span class="p">[</span><span class="n">pred_selected_i</span><span class="p">,</span> <span class="n">target_sorted_i</span><span class="p">]</span> <span class="o">&gt;</span> <span class="n">iou_thresholds</span>
  798. <span class="c1"># Vector[j], True when both pred_i and target_i are not matched yet for the (j)th threshold</span>
  799. <span class="n">are_candidates_free</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">logical_and</span><span class="p">(</span><span class="o">~</span><span class="n">preds_matched</span><span class="p">[</span><span class="n">pred_i</span><span class="p">,</span> <span class="p">:],</span> <span class="o">~</span><span class="n">targets_matched</span><span class="p">[</span><span class="n">target_i</span><span class="p">,</span> <span class="p">:])</span>
  800. <span class="c1"># Vector[j], True when (pred_i, target_i) can be matched for the (j)th threshold</span>
  801. <span class="n">are_candidates_good</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">logical_and</span><span class="p">(</span><span class="n">is_iou_above_threshold</span><span class="p">,</span> <span class="n">are_candidates_free</span><span class="p">)</span>
  802. <span class="c1"># For every threshold (j) where target_i and pred_i can be matched together ( are_candidates_good[j]==True )</span>
  803. <span class="c1"># fill the matching placeholders with True</span>
  804. <span class="n">targets_matched</span><span class="p">[</span><span class="n">target_i</span><span class="p">,</span> <span class="n">are_candidates_good</span><span class="p">]</span> <span class="o">=</span> <span class="kc">True</span>
  805. <span class="n">preds_matched</span><span class="p">[</span><span class="n">pred_i</span><span class="p">,</span> <span class="n">are_candidates_good</span><span class="p">]</span> <span class="o">=</span> <span class="kc">True</span>
  806. <span class="c1"># When all the targets are matched with a prediction for every IoU Threshold, stop.</span>
  807. <span class="k">if</span> <span class="n">targets_matched</span><span class="o">.</span><span class="n">all</span><span class="p">():</span>
  808. <span class="k">break</span>
  809. <span class="c1"># Crowd targets can be matched with many predictions.</span>
  810. <span class="c1"># Therefore, for every prediction we just need to check if it has IoA large enough with any crowd target.</span>
  811. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">crowd_targets</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
  812. <span class="c1"># shape = (n_preds_to_use x n_crowd_targets)</span>
  813. <span class="n">ioa</span> <span class="o">=</span> <span class="n">crowd_ioa</span><span class="p">(</span><span class="n">preds_box</span><span class="p">[</span><span class="n">preds_idx_to_use</span><span class="p">],</span> <span class="n">crowd_target_box</span><span class="p">)</span>
  814. <span class="c1"># Fill IoA values at index (i, j) with 0 when the prediction (i) and target(j) are of different class</span>
  815. <span class="c1"># Filling with 0 is equivalent to ignore these values since with want IoA &gt; threshold &gt; 0</span>
  816. <span class="n">cls_mismatch</span> <span class="o">=</span> <span class="p">(</span><span class="n">preds_cls</span><span class="p">[</span><span class="n">preds_idx_to_use</span><span class="p">]</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">!=</span> <span class="n">crowd_targets_cls</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span>
  817. <span class="n">ioa</span><span class="p">[</span><span class="n">cls_mismatch</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span>
  818. <span class="c1"># For each prediction, we keep it&#39;s highest score with any crowd target (of same class)</span>
  819. <span class="c1"># shape = (n_preds_to_use)</span>
  820. <span class="n">best_ioa</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">ioa</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
  821. <span class="c1"># If a prediction has IoA higher than threshold (with any target of same class), then there is a match</span>
  822. <span class="c1"># shape = (n_preds_to_use x iou_thresholds)</span>
  823. <span class="n">is_matching_with_crowd</span> <span class="o">=</span> <span class="p">(</span><span class="n">best_ioa</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">&gt;</span> <span class="n">iou_thresholds</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span>
  824. <span class="n">preds_to_ignore</span><span class="p">[</span><span class="n">preds_idx_to_use</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">logical_or</span><span class="p">(</span><span class="n">preds_to_ignore</span><span class="p">[</span><span class="n">preds_idx_to_use</span><span class="p">],</span> <span class="n">is_matching_with_crowd</span><span class="p">)</span>
  825. <span class="k">if</span> <span class="n">return_on_cpu</span><span class="p">:</span>
  826. <span class="n">preds_matched</span> <span class="o">=</span> <span class="n">preds_matched</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s2">&quot;cpu&quot;</span><span class="p">)</span>
  827. <span class="n">preds_to_ignore</span> <span class="o">=</span> <span class="n">preds_to_ignore</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s2">&quot;cpu&quot;</span><span class="p">)</span>
  828. <span class="n">preds_scores</span> <span class="o">=</span> <span class="n">preds_scores</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s2">&quot;cpu&quot;</span><span class="p">)</span>
  829. <span class="n">preds_cls</span> <span class="o">=</span> <span class="n">preds_cls</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s2">&quot;cpu&quot;</span><span class="p">)</span>
  830. <span class="n">targets_cls</span> <span class="o">=</span> <span class="n">targets_cls</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s2">&quot;cpu&quot;</span><span class="p">)</span>
  831. <span class="k">return</span> <span class="n">preds_matched</span><span class="p">,</span> <span class="n">preds_to_ignore</span><span class="p">,</span> <span class="n">preds_scores</span><span class="p">,</span> <span class="n">preds_cls</span><span class="p">,</span> <span class="n">targets_cls</span></div>
  832. <div class="viewcode-block" id="get_top_k_idx_per_cls"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.get_top_k_idx_per_cls">[docs]</a><span class="k">def</span> <span class="nf">get_top_k_idx_per_cls</span><span class="p">(</span><span class="n">preds_scores</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">preds_cls</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">top_k</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
  833. <span class="sd">&quot;&quot;&quot;Get the indexes of all the top k predictions for every class</span>
  834. <span class="sd"> :param preds_scores: The confidence scores, vector of shape (n_pred)</span>
  835. <span class="sd"> :param preds_cls: The predicted class, vector of shape (n_pred)</span>
  836. <span class="sd"> :param top_k: Number of predictions to keep per class, ordered by confidence score</span>
  837. <span class="sd"> :return top_k_idx: Indexes of the top k predictions. length &lt;= (k * n_unique_class)</span>
  838. <span class="sd"> &quot;&quot;&quot;</span>
  839. <span class="n">n_unique_cls</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">preds_cls</span><span class="p">)</span>
  840. <span class="n">mask</span> <span class="o">=</span> <span class="p">(</span><span class="n">preds_cls</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">n_unique_cls</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">preds_scores</span><span class="o">.</span><span class="n">device</span><span class="p">)</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span>
  841. <span class="n">preds_scores_per_cls</span> <span class="o">=</span> <span class="n">preds_scores</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">mask</span>
  842. <span class="n">sorted_scores_per_cls</span><span class="p">,</span> <span class="n">sorting_idx</span> <span class="o">=</span> <span class="n">preds_scores_per_cls</span><span class="o">.</span><span class="n">sort</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">descending</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  843. <span class="n">idx_with_satisfying_scores</span> <span class="o">=</span> <span class="n">sorted_scores_per_cls</span><span class="p">[:</span><span class="n">top_k</span><span class="p">,</span> <span class="p">:]</span><span class="o">.</span><span class="n">nonzero</span><span class="p">(</span><span class="n">as_tuple</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
  844. <span class="n">top_k_idx</span> <span class="o">=</span> <span class="n">sorting_idx</span><span class="p">[</span><span class="n">idx_with_satisfying_scores</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)]</span>
  845. <span class="k">return</span> <span class="n">top_k_idx</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span></div>
  846. <div class="viewcode-block" id="compute_detection_metrics"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.compute_detection_metrics">[docs]</a><span class="k">def</span> <span class="nf">compute_detection_metrics</span><span class="p">(</span>
  847. <span class="n">preds_matched</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
  848. <span class="n">preds_to_ignore</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
  849. <span class="n">preds_scores</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
  850. <span class="n">preds_cls</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
  851. <span class="n">targets_cls</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
  852. <span class="n">device</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
  853. <span class="n">recall_thresholds</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
  854. <span class="n">score_threshold</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
  855. <span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">:</span>
  856. <span class="sd">&quot;&quot;&quot;</span>
  857. <span class="sd"> Compute the list of precision, recall, MaP and f1 for every recall IoU threshold and for every class.</span>
  858. <span class="sd"> :param preds_matched: Tensor of shape (num_predictions, n_iou_thresholds)</span>
  859. <span class="sd"> True when prediction (i) is matched with a target with respect to the (j)th IoU threshold</span>
  860. <span class="sd"> :param preds_to_ignore Tensor of shape (num_predictions, n_iou_thresholds)</span>
  861. <span class="sd"> True when prediction (i) is matched with a crowd target with respect to the (j)th IoU threshold</span>
  862. <span class="sd"> :param preds_scores: Tensor of shape (num_predictions), confidence score for every prediction</span>
  863. <span class="sd"> :param preds_cls: Tensor of shape (num_predictions), predicted class for every prediction</span>
  864. <span class="sd"> :param targets_cls: Tensor of shape (num_targets), ground truth class for every target box to be detected</span>
  865. <span class="sd"> :param recall_thresholds: Recall thresholds used to compute MaP.</span>
  866. <span class="sd"> :param score_threshold: Minimum confidence score to consider a prediction for the computation of</span>
  867. <span class="sd"> precision, recall and f1 (not MaP)</span>
  868. <span class="sd"> :param device: Device</span>
  869. <span class="sd"> :return:</span>
  870. <span class="sd"> :ap, precision, recall, f1: Tensors of shape (n_class, nb_iou_thrs)</span>
  871. <span class="sd"> :unique_classes: Vector with all unique target classes</span>
  872. <span class="sd"> &quot;&quot;&quot;</span>
  873. <span class="n">preds_matched</span><span class="p">,</span> <span class="n">preds_to_ignore</span> <span class="o">=</span> <span class="n">preds_matched</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> <span class="n">preds_to_ignore</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
  874. <span class="n">preds_scores</span><span class="p">,</span> <span class="n">preds_cls</span><span class="p">,</span> <span class="n">targets_cls</span> <span class="o">=</span> <span class="n">preds_scores</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> <span class="n">preds_cls</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> <span class="n">targets_cls</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
  875. <span class="n">recall_thresholds</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">101</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span> <span class="k">if</span> <span class="n">recall_thresholds</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">recall_thresholds</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
  876. <span class="n">unique_classes</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">unique</span><span class="p">(</span><span class="n">targets_cls</span><span class="p">)</span>
  877. <span class="n">n_class</span><span class="p">,</span> <span class="n">nb_iou_thrs</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">unique_classes</span><span class="p">),</span> <span class="n">preds_matched</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
  878. <span class="n">ap</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">n_class</span><span class="p">,</span> <span class="n">nb_iou_thrs</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
  879. <span class="n">precision</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">n_class</span><span class="p">,</span> <span class="n">nb_iou_thrs</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
  880. <span class="n">recall</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">n_class</span><span class="p">,</span> <span class="n">nb_iou_thrs</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
  881. <span class="k">for</span> <span class="n">cls_i</span><span class="p">,</span> <span class="bp">cls</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">unique_classes</span><span class="p">):</span>
  882. <span class="n">cls_preds_idx</span><span class="p">,</span> <span class="n">cls_targets_idx</span> <span class="o">=</span> <span class="p">(</span><span class="n">preds_cls</span> <span class="o">==</span> <span class="bp">cls</span><span class="p">),</span> <span class="p">(</span><span class="n">targets_cls</span> <span class="o">==</span> <span class="bp">cls</span><span class="p">)</span>
  883. <span class="n">cls_ap</span><span class="p">,</span> <span class="n">cls_precision</span><span class="p">,</span> <span class="n">cls_recall</span> <span class="o">=</span> <span class="n">compute_detection_metrics_per_cls</span><span class="p">(</span>
  884. <span class="n">preds_matched</span><span class="o">=</span><span class="n">preds_matched</span><span class="p">[</span><span class="n">cls_preds_idx</span><span class="p">],</span>
  885. <span class="n">preds_to_ignore</span><span class="o">=</span><span class="n">preds_to_ignore</span><span class="p">[</span><span class="n">cls_preds_idx</span><span class="p">],</span>
  886. <span class="n">preds_scores</span><span class="o">=</span><span class="n">preds_scores</span><span class="p">[</span><span class="n">cls_preds_idx</span><span class="p">],</span>
  887. <span class="n">n_targets</span><span class="o">=</span><span class="n">cls_targets_idx</span><span class="o">.</span><span class="n">sum</span><span class="p">(),</span>
  888. <span class="n">recall_thresholds</span><span class="o">=</span><span class="n">recall_thresholds</span><span class="p">,</span>
  889. <span class="n">score_threshold</span><span class="o">=</span><span class="n">score_threshold</span><span class="p">,</span>
  890. <span class="n">device</span><span class="o">=</span><span class="n">device</span>
  891. <span class="p">)</span>
  892. <span class="n">ap</span><span class="p">[</span><span class="n">cls_i</span><span class="p">,</span> <span class="p">:]</span> <span class="o">=</span> <span class="n">cls_ap</span>
  893. <span class="n">precision</span><span class="p">[</span><span class="n">cls_i</span><span class="p">,</span> <span class="p">:]</span> <span class="o">=</span> <span class="n">cls_precision</span>
  894. <span class="n">recall</span><span class="p">[</span><span class="n">cls_i</span><span class="p">,</span> <span class="p">:]</span> <span class="o">=</span> <span class="n">cls_recall</span>
  895. <span class="n">f1</span> <span class="o">=</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">precision</span> <span class="o">*</span> <span class="n">recall</span> <span class="o">/</span> <span class="p">(</span><span class="n">precision</span> <span class="o">+</span> <span class="n">recall</span> <span class="o">+</span> <span class="mf">1e-16</span><span class="p">)</span>
  896. <span class="k">return</span> <span class="n">ap</span><span class="p">,</span> <span class="n">precision</span><span class="p">,</span> <span class="n">recall</span><span class="p">,</span> <span class="n">f1</span><span class="p">,</span> <span class="n">unique_classes</span></div>
  897. <div class="viewcode-block" id="compute_detection_metrics_per_cls"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.detection_utils.compute_detection_metrics_per_cls">[docs]</a><span class="k">def</span> <span class="nf">compute_detection_metrics_per_cls</span><span class="p">(</span>
  898. <span class="n">preds_matched</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
  899. <span class="n">preds_to_ignore</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
  900. <span class="n">preds_scores</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
  901. <span class="n">n_targets</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
  902. <span class="n">recall_thresholds</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
  903. <span class="n">score_threshold</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
  904. <span class="n">device</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
  905. <span class="p">):</span>
  906. <span class="sd">&quot;&quot;&quot;</span>
  907. <span class="sd"> Compute the list of precision, recall and MaP of a given class for every recall IoU threshold.</span>
  908. <span class="sd"> :param preds_matched: Tensor of shape (num_predictions, n_iou_thresholds)</span>
  909. <span class="sd"> True when prediction (i) is matched with a target</span>
  910. <span class="sd"> with respect to the(j)th IoU threshold</span>
  911. <span class="sd"> :param preds_to_ignore Tensor of shape (num_predictions, n_iou_thresholds)</span>
  912. <span class="sd"> True when prediction (i) is matched with a crowd target</span>
  913. <span class="sd"> with respect to the (j)th IoU threshold</span>
  914. <span class="sd"> :param preds_scores: Tensor of shape (num_predictions), confidence score for every prediction</span>
  915. <span class="sd"> :param n_targets: Number of target boxes of this class</span>
  916. <span class="sd"> :param recall_thresholds: Tensor of shape (max_n_rec_thresh) list of recall thresholds used to compute MaP</span>
  917. <span class="sd"> :param score_threshold: Minimum confidence score to consider a prediction for the computation of</span>
  918. <span class="sd"> precision and recall (not MaP)</span>
  919. <span class="sd"> :param device: Device</span>
  920. <span class="sd"> :return ap, precision, recall: Tensors of shape (nb_iou_thrs)</span>
  921. <span class="sd"> &quot;&quot;&quot;</span>
  922. <span class="n">nb_iou_thrs</span> <span class="o">=</span> <span class="n">preds_matched</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
  923. <span class="n">tps</span> <span class="o">=</span> <span class="n">preds_matched</span>
  924. <span class="n">fps</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">logical_and</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">logical_not</span><span class="p">(</span><span class="n">preds_matched</span><span class="p">),</span> <span class="n">torch</span><span class="o">.</span><span class="n">logical_not</span><span class="p">(</span><span class="n">preds_to_ignore</span><span class="p">))</span>
  925. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">tps</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
  926. <span class="k">return</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">nb_iou_thrs</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
  927. <span class="c1"># Sort by decreasing score</span>
  928. <span class="n">dtype</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">uint8</span> <span class="k">if</span> <span class="n">preds_scores</span><span class="o">.</span><span class="n">is_cuda</span> <span class="ow">and</span> <span class="n">preds_scores</span><span class="o">.</span><span class="n">dtype</span> <span class="ow">is</span> <span class="n">torch</span><span class="o">.</span><span class="n">bool</span> <span class="k">else</span> <span class="n">preds_scores</span><span class="o">.</span><span class="n">dtype</span>
  929. <span class="n">sort_ind</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">argsort</span><span class="p">(</span><span class="n">preds_scores</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">dtype</span><span class="p">),</span> <span class="n">descending</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  930. <span class="n">tps</span> <span class="o">=</span> <span class="n">tps</span><span class="p">[</span><span class="n">sort_ind</span><span class="p">,</span> <span class="p">:]</span>
  931. <span class="n">fps</span> <span class="o">=</span> <span class="n">fps</span><span class="p">[</span><span class="n">sort_ind</span><span class="p">,</span> <span class="p">:]</span>
  932. <span class="n">preds_scores</span> <span class="o">=</span> <span class="n">preds_scores</span><span class="p">[</span><span class="n">sort_ind</span><span class="p">]</span>
  933. <span class="c1"># Rolling sum over the predictions</span>
  934. <span class="n">rolling_tps</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">tps</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">)</span>
  935. <span class="n">rolling_fps</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">fps</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">)</span>
  936. <span class="n">rolling_recalls</span> <span class="o">=</span> <span class="n">rolling_tps</span> <span class="o">/</span> <span class="n">n_targets</span>
  937. <span class="n">rolling_precisions</span> <span class="o">=</span> <span class="n">rolling_tps</span> <span class="o">/</span> <span class="p">(</span><span class="n">rolling_tps</span> <span class="o">+</span> <span class="n">rolling_fps</span> <span class="o">+</span> <span class="n">torch</span><span class="o">.</span><span class="n">finfo</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">float64</span><span class="p">)</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span>
  938. <span class="c1"># Reversed cummax to only have decreasing values</span>
  939. <span class="n">rolling_precisions</span> <span class="o">=</span> <span class="n">rolling_precisions</span><span class="o">.</span><span class="n">flip</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">cummax</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">values</span><span class="o">.</span><span class="n">flip</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
  940. <span class="c1"># ==================</span>
  941. <span class="c1"># RECALL &amp; PRECISION</span>
  942. <span class="c1"># We want the rolling precision/recall at index i so that: preds_scores[i-1] &gt;= score_threshold &gt; preds_scores[i]</span>
  943. <span class="c1"># Note: torch.searchsorted works on increasing sequence and preds_scores is decreasing, so we work with &quot;-&quot;</span>
  944. <span class="n">lowest_score_above_threshold</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">searchsorted</span><span class="p">(</span><span class="o">-</span><span class="n">preds_scores</span><span class="p">,</span> <span class="o">-</span><span class="n">score_threshold</span><span class="p">,</span> <span class="n">right</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
  945. <span class="k">if</span> <span class="n">lowest_score_above_threshold</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span> <span class="c1"># Here score_threshold &gt; preds_scores[0], so no pred is above the threshold</span>
  946. <span class="n">recall</span> <span class="o">=</span> <span class="mi">0</span>
  947. <span class="n">precision</span> <span class="o">=</span> <span class="mi">0</span> <span class="c1"># the precision is not really defined when no pred but we need to give it a value</span>
  948. <span class="k">else</span><span class="p">:</span>
  949. <span class="n">recall</span> <span class="o">=</span> <span class="n">rolling_recalls</span><span class="p">[</span><span class="n">lowest_score_above_threshold</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span>
  950. <span class="n">precision</span> <span class="o">=</span> <span class="n">rolling_precisions</span><span class="p">[</span><span class="n">lowest_score_above_threshold</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span>
  951. <span class="c1"># ==================</span>
  952. <span class="c1"># AVERAGE PRECISION</span>
  953. <span class="c1"># shape = (nb_iou_thrs, n_recall_thresholds)</span>
  954. <span class="n">recall_thresholds</span> <span class="o">=</span> <span class="n">recall_thresholds</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span><span class="n">nb_iou_thrs</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
  955. <span class="c1"># We want the index i so that: rolling_recalls[i-1] &lt; recall_thresholds[k] &lt;= rolling_recalls[i]</span>
  956. <span class="c1"># Note: when recall_thresholds[k] &gt; max(rolling_recalls), i = len(rolling_recalls)</span>
  957. <span class="c1"># Note2: we work with transpose (.T) to apply torch.searchsorted on first dim instead of the last one</span>
  958. <span class="n">recall_threshold_idx</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">searchsorted</span><span class="p">(</span><span class="n">rolling_recalls</span><span class="o">.</span><span class="n">T</span><span class="p">,</span> <span class="n">recall_thresholds</span><span class="p">,</span> <span class="n">right</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span><span class="o">.</span><span class="n">T</span>
  959. <span class="c1"># When recall_thresholds[k] &gt; max(rolling_recalls), rolling_precisions[i] is not defined, and we want precision = 0</span>
  960. <span class="n">rolling_precisions</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">rolling_precisions</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">nb_iou_thrs</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
  961. <span class="c1"># shape = (n_recall_thresholds, nb_iou_thrs)</span>
  962. <span class="n">sampled_precision_points</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">gather</span><span class="p">(</span><span class="nb">input</span><span class="o">=</span><span class="n">rolling_precisions</span><span class="p">,</span> <span class="n">index</span><span class="o">=</span><span class="n">recall_threshold_idx</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
  963. <span class="c1"># Average over the recall_thresholds</span>
  964. <span class="n">ap</span> <span class="o">=</span> <span class="n">sampled_precision_points</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
  965. <span class="k">return</span> <span class="n">ap</span><span class="p">,</span> <span class="n">precision</span><span class="p">,</span> <span class="n">recall</span></div>
  966. </pre></div>
  967. </div>
  968. </div>
  969. <footer>
  970. <hr/>
  971. <div role="contentinfo">
  972. <p>&#169; Copyright 2021, SuperGradients team.</p>
  973. </div>
  974. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  975. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  976. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  977. </footer>
  978. </div>
  979. </div>
  980. </section>
  981. </div>
  982. <script>
  983. jQuery(function () {
  984. SphinxRtdTheme.Navigation.enable(true);
  985. });
  986. </script>
  987. </body>
  988. </html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.training.utils.distributed_training_utils &mdash; SuperGradients 1.0 documentation</title>
  7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
  10. <!--[if lt IE 9]>
  11. <script src="../../../../_static/js/html5shiv.min.js"></script>
  12. <![endif]-->
  13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
  14. <script src="../../../../_static/jquery.js"></script>
  15. <script src="../../../../_static/underscore.js"></script>
  16. <script src="../../../../_static/doctools.js"></script>
  17. <script src="../../../../_static/js/theme.js"></script>
  18. <link rel="index" title="Index" href="../../../../genindex.html" />
  19. <link rel="search" title="Search" href="../../../../search.html" />
  20. </head>
  21. <body class="wy-body-for-nav">
  22. <div class="wy-grid-for-nav">
  23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  24. <div class="wy-side-scroll">
  25. <div class="wy-side-nav-search" >
  26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
  27. </a>
  28. <div role="search">
  29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
  30. <input type="text" name="q" placeholder="Search docs" />
  31. <input type="hidden" name="check_keywords" value="yes" />
  32. <input type="hidden" name="area" value="default" />
  33. </form>
  34. </div>
  35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
  37. <ul>
  38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
  39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
  40. </ul>
  41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
  42. <ul>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
  45. </ul>
  46. <p class="caption"><span class="caption-text">User Guide</span></p>
  47. <ul>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
  57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
  58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
  59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
  60. </ul>
  61. </div>
  62. </div>
  63. </nav>
  64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  66. <a href="../../../../index.html">SuperGradients</a>
  67. </nav>
  68. <div class="wy-nav-content">
  69. <div class="rst-content">
  70. <div role="navigation" aria-label="Page navigation">
  71. <ul class="wy-breadcrumbs">
  72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
  74. <li>super_gradients.training.utils.distributed_training_utils</li>
  75. <li class="wy-breadcrumbs-aside">
  76. </li>
  77. </ul>
  78. <hr/>
  79. </div>
  80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  81. <div itemprop="articleBody">
  82. <h1>Source code for super_gradients.training.utils.distributed_training_utils</h1><div class="highlight"><pre>
  83. <span></span><span class="kn">import</span> <span class="nn">torch</span>
  84. <span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">distributed</span> <span class="k">as</span> <span class="n">dist</span>
  85. <span class="kn">from</span> <span class="nn">torch.cuda.amp</span> <span class="kn">import</span> <span class="n">autocast</span>
  86. <span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
  87. <span class="kn">import</span> <span class="nn">itertools</span>
  88. <span class="kn">from</span> <span class="nn">contextlib</span> <span class="kn">import</span> <span class="n">contextmanager</span>
  89. <div class="viewcode-block" id="distributed_all_reduce_tensor_average"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.distributed_training_utils.distributed_all_reduce_tensor_average">[docs]</a><span class="k">def</span> <span class="nf">distributed_all_reduce_tensor_average</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">n</span><span class="p">):</span>
  90. <span class="sd">&quot;&quot;&quot;</span>
  91. <span class="sd"> This method performs a reduce operation on multiple nodes running distributed training</span>
  92. <span class="sd"> It first sums all of the results and then divides the summation</span>
  93. <span class="sd"> :param tensor: The tensor to perform the reduce operation for</span>
  94. <span class="sd"> :param n: Number of nodes</span>
  95. <span class="sd"> :return: Averaged tensor from all of the nodes</span>
  96. <span class="sd"> &quot;&quot;&quot;</span>
  97. <span class="n">rt</span> <span class="o">=</span> <span class="n">tensor</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
  98. <span class="n">torch</span><span class="o">.</span><span class="n">distributed</span><span class="o">.</span><span class="n">all_reduce</span><span class="p">(</span><span class="n">rt</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">distributed</span><span class="o">.</span><span class="n">ReduceOp</span><span class="o">.</span><span class="n">SUM</span><span class="p">)</span>
  99. <span class="n">rt</span> <span class="o">/=</span> <span class="n">n</span>
  100. <span class="k">return</span> <span class="n">rt</span></div>
  101. <div class="viewcode-block" id="reduce_results_tuple_for_ddp"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.distributed_training_utils.reduce_results_tuple_for_ddp">[docs]</a><span class="k">def</span> <span class="nf">reduce_results_tuple_for_ddp</span><span class="p">(</span><span class="n">validation_results_tuple</span><span class="p">,</span> <span class="n">device</span><span class="p">):</span>
  102. <span class="sd">&quot;&quot;&quot;Gather all validation tuples from the various devices and average them&quot;&quot;&quot;</span>
  103. <span class="n">validation_results_list</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">validation_results_tuple</span><span class="p">)</span>
  104. <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">validation_result</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">validation_results_list</span><span class="p">):</span>
  105. <span class="n">validation_results_list</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">distributed_all_reduce_tensor_average</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">validation_result</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span>
  106. <span class="n">torch</span><span class="o">.</span><span class="n">distributed</span><span class="o">.</span><span class="n">get_world_size</span><span class="p">())</span>
  107. <span class="n">validation_results_tuple</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">validation_results_list</span><span class="p">)</span>
  108. <span class="k">return</span> <span class="n">validation_results_tuple</span></div>
  109. <div class="viewcode-block" id="MultiGPUModeAutocastWrapper"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.distributed_training_utils.MultiGPUModeAutocastWrapper">[docs]</a><span class="k">class</span> <span class="nc">MultiGPUModeAutocastWrapper</span><span class="p">():</span>
  110. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">func</span><span class="p">):</span>
  111. <span class="bp">self</span><span class="o">.</span><span class="n">func</span> <span class="o">=</span> <span class="n">func</span>
  112. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  113. <span class="k">with</span> <span class="n">autocast</span><span class="p">():</span>
  114. <span class="n">out</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">func</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
  115. <span class="k">return</span> <span class="n">out</span></div>
  116. <div class="viewcode-block" id="scaled_all_reduce"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.distributed_training_utils.scaled_all_reduce">[docs]</a><span class="k">def</span> <span class="nf">scaled_all_reduce</span><span class="p">(</span><span class="n">tensors</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">num_gpus</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
  117. <span class="sd">&quot;&quot;&quot;</span>
  118. <span class="sd"> Performs the scaled all_reduce operation on the provided tensors.</span>
  119. <span class="sd"> The input tensors are modified in-place.</span>
  120. <span class="sd"> Currently supports only the sum</span>
  121. <span class="sd"> reduction operator.</span>
  122. <span class="sd"> The reduced values are scaled by the inverse size of the</span>
  123. <span class="sd"> process group (equivalent to num_gpus).</span>
  124. <span class="sd"> &quot;&quot;&quot;</span>
  125. <span class="c1"># There is no need for reduction in the single-proc case</span>
  126. <span class="k">if</span> <span class="n">num_gpus</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
  127. <span class="k">return</span> <span class="n">tensors</span>
  128. <span class="c1"># Queue the reductions</span>
  129. <span class="n">reductions</span> <span class="o">=</span> <span class="p">[]</span>
  130. <span class="k">for</span> <span class="n">tensor</span> <span class="ow">in</span> <span class="n">tensors</span><span class="p">:</span>
  131. <span class="n">reduction</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">distributed</span><span class="o">.</span><span class="n">all_reduce</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">async_op</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  132. <span class="n">reductions</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">reduction</span><span class="p">)</span>
  133. <span class="c1"># Wait for reductions to finish</span>
  134. <span class="k">for</span> <span class="n">reduction</span> <span class="ow">in</span> <span class="n">reductions</span><span class="p">:</span>
  135. <span class="n">reduction</span><span class="o">.</span><span class="n">wait</span><span class="p">()</span>
  136. <span class="c1"># Scale the results</span>
  137. <span class="k">for</span> <span class="n">tensor</span> <span class="ow">in</span> <span class="n">tensors</span><span class="p">:</span>
  138. <span class="n">tensor</span><span class="o">.</span><span class="n">mul_</span><span class="p">(</span><span class="mf">1.0</span> <span class="o">/</span> <span class="n">num_gpus</span><span class="p">)</span>
  139. <span class="k">return</span> <span class="n">tensors</span></div>
  140. <div class="viewcode-block" id="compute_precise_bn_stats"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.distributed_training_utils.compute_precise_bn_stats">[docs]</a><span class="nd">@torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">()</span>
  141. <span class="k">def</span> <span class="nf">compute_precise_bn_stats</span><span class="p">(</span><span class="n">model</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">loader</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">,</span> <span class="n">precise_bn_batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">num_gpus</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
  142. <span class="sd">&#39;&#39;&#39;</span>
  143. <span class="sd"> :param model: The model being trained (ie: SgModel.net)</span>
  144. <span class="sd"> :param loader: Training dataloader (ie: SgModel.train_loader)</span>
  145. <span class="sd"> :param precise_bn_batch_size: The effective batch size we want to calculate the batchnorm on. For example, if we are training a model</span>
  146. <span class="sd"> on 8 gpus, with a batch of 128 on each gpu, a good rule of thumb would be to give it 8192</span>
  147. <span class="sd"> (ie: effective_batch_size * num_gpus = batch_per_gpu * num_gpus * num_gpus).</span>
  148. <span class="sd"> If precise_bn_batch_size is not provided in the training_params, the latter heuristic</span>
  149. <span class="sd"> will be taken.</span>
  150. <span class="sd"> param num_gpus: The number of gpus we are training on</span>
  151. <span class="sd"> &#39;&#39;&#39;</span>
  152. <span class="c1"># Compute the number of minibatches to use</span>
  153. <span class="n">num_iter</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">precise_bn_batch_size</span> <span class="o">/</span> <span class="p">(</span><span class="n">loader</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">*</span> <span class="n">num_gpus</span><span class="p">))</span> <span class="k">if</span> <span class="n">precise_bn_batch_size</span> <span class="k">else</span> <span class="n">num_gpus</span>
  154. <span class="n">num_iter</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">num_iter</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">loader</span><span class="p">))</span>
  155. <span class="c1"># Retrieve the BN layers</span>
  156. <span class="n">bns</span> <span class="o">=</span> <span class="p">[</span><span class="n">m</span> <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="n">model</span><span class="o">.</span><span class="n">modules</span><span class="p">()</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm2d</span><span class="p">)]</span>
  157. <span class="c1"># Initialize BN stats storage for computing mean(mean(batch)) and mean(var(batch))</span>
  158. <span class="n">running_means</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">bn</span><span class="o">.</span><span class="n">running_mean</span><span class="p">)</span> <span class="k">for</span> <span class="n">bn</span> <span class="ow">in</span> <span class="n">bns</span><span class="p">]</span>
  159. <span class="n">running_vars</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">bn</span><span class="o">.</span><span class="n">running_var</span><span class="p">)</span> <span class="k">for</span> <span class="n">bn</span> <span class="ow">in</span> <span class="n">bns</span><span class="p">]</span>
  160. <span class="c1"># Remember momentum values</span>
  161. <span class="n">momentums</span> <span class="o">=</span> <span class="p">[</span><span class="n">bn</span><span class="o">.</span><span class="n">momentum</span> <span class="k">for</span> <span class="n">bn</span> <span class="ow">in</span> <span class="n">bns</span><span class="p">]</span>
  162. <span class="c1"># Set momentum to 1.0 to compute BN stats that only reflect the current batch</span>
  163. <span class="k">for</span> <span class="n">bn</span> <span class="ow">in</span> <span class="n">bns</span><span class="p">:</span>
  164. <span class="n">bn</span><span class="o">.</span><span class="n">momentum</span> <span class="o">=</span> <span class="mf">1.0</span>
  165. <span class="c1"># Average the BN stats for each BN layer over the batches</span>
  166. <span class="k">for</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">_labels</span> <span class="ow">in</span> <span class="n">itertools</span><span class="o">.</span><span class="n">islice</span><span class="p">(</span><span class="n">loader</span><span class="p">,</span> <span class="n">num_iter</span><span class="p">):</span>
  167. <span class="n">model</span><span class="p">(</span><span class="n">inputs</span><span class="o">.</span><span class="n">cuda</span><span class="p">())</span>
  168. <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">bn</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">bns</span><span class="p">):</span>
  169. <span class="n">running_means</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">+=</span> <span class="n">bn</span><span class="o">.</span><span class="n">running_mean</span> <span class="o">/</span> <span class="n">num_iter</span>
  170. <span class="n">running_vars</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">+=</span> <span class="n">bn</span><span class="o">.</span><span class="n">running_var</span> <span class="o">/</span> <span class="n">num_iter</span>
  171. <span class="c1"># Sync BN stats across GPUs (no reduction if 1 GPU used)</span>
  172. <span class="n">running_means</span> <span class="o">=</span> <span class="n">scaled_all_reduce</span><span class="p">(</span><span class="n">running_means</span><span class="p">,</span> <span class="n">num_gpus</span><span class="o">=</span><span class="n">num_gpus</span><span class="p">)</span>
  173. <span class="n">running_vars</span> <span class="o">=</span> <span class="n">scaled_all_reduce</span><span class="p">(</span><span class="n">running_vars</span><span class="p">,</span> <span class="n">num_gpus</span><span class="o">=</span><span class="n">num_gpus</span><span class="p">)</span>
  174. <span class="c1"># Set BN stats and restore original momentum values</span>
  175. <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">bn</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">bns</span><span class="p">):</span>
  176. <span class="n">bn</span><span class="o">.</span><span class="n">running_mean</span> <span class="o">=</span> <span class="n">running_means</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
  177. <span class="n">bn</span><span class="o">.</span><span class="n">running_var</span> <span class="o">=</span> <span class="n">running_vars</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
  178. <span class="n">bn</span><span class="o">.</span><span class="n">momentum</span> <span class="o">=</span> <span class="n">momentums</span><span class="p">[</span><span class="n">i</span><span class="p">]</span></div>
  179. <div class="viewcode-block" id="get_local_rank"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.distributed_training_utils.get_local_rank">[docs]</a><span class="k">def</span> <span class="nf">get_local_rank</span><span class="p">():</span>
  180. <span class="sd">&quot;&quot;&quot;</span>
  181. <span class="sd"> Returns the local rank if running in DDP, and 0 otherwise</span>
  182. <span class="sd"> :return: local rank</span>
  183. <span class="sd"> &quot;&quot;&quot;</span>
  184. <span class="k">return</span> <span class="n">dist</span><span class="o">.</span><span class="n">get_rank</span><span class="p">()</span> <span class="k">if</span> <span class="n">dist</span><span class="o">.</span><span class="n">is_initialized</span><span class="p">()</span> <span class="k">else</span> <span class="mi">0</span></div>
  185. <div class="viewcode-block" id="get_world_size"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.distributed_training_utils.get_world_size">[docs]</a><span class="k">def</span> <span class="nf">get_world_size</span><span class="p">()</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
  186. <span class="sd">&quot;&quot;&quot;</span>
  187. <span class="sd"> Returns the world size if running in DDP, and 1 otherwise</span>
  188. <span class="sd"> :return: world size</span>
  189. <span class="sd"> &quot;&quot;&quot;</span>
  190. <span class="k">if</span> <span class="ow">not</span> <span class="n">dist</span><span class="o">.</span><span class="n">is_available</span><span class="p">():</span>
  191. <span class="k">return</span> <span class="mi">1</span>
  192. <span class="k">if</span> <span class="ow">not</span> <span class="n">dist</span><span class="o">.</span><span class="n">is_initialized</span><span class="p">():</span>
  193. <span class="k">return</span> <span class="mi">1</span>
  194. <span class="k">return</span> <span class="n">dist</span><span class="o">.</span><span class="n">get_world_size</span><span class="p">()</span></div>
  195. <div class="viewcode-block" id="wait_for_the_master"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.distributed_training_utils.wait_for_the_master">[docs]</a><span class="nd">@contextmanager</span>
  196. <span class="k">def</span> <span class="nf">wait_for_the_master</span><span class="p">(</span><span class="n">local_rank</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
  197. <span class="sd">&quot;&quot;&quot;</span>
  198. <span class="sd"> Make all processes waiting for the master to do some task.</span>
  199. <span class="sd"> &quot;&quot;&quot;</span>
  200. <span class="k">if</span> <span class="n">local_rank</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
  201. <span class="n">dist</span><span class="o">.</span><span class="n">barrier</span><span class="p">()</span>
  202. <span class="k">yield</span>
  203. <span class="k">if</span> <span class="n">local_rank</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
  204. <span class="k">if</span> <span class="ow">not</span> <span class="n">dist</span><span class="o">.</span><span class="n">is_available</span><span class="p">():</span>
  205. <span class="k">return</span>
  206. <span class="k">if</span> <span class="ow">not</span> <span class="n">dist</span><span class="o">.</span><span class="n">is_initialized</span><span class="p">():</span>
  207. <span class="k">return</span>
  208. <span class="k">else</span><span class="p">:</span>
  209. <span class="n">dist</span><span class="o">.</span><span class="n">barrier</span><span class="p">()</span></div>
  210. </pre></div>
  211. </div>
  212. </div>
  213. <footer>
  214. <hr/>
  215. <div role="contentinfo">
  216. <p>&#169; Copyright 2021, SuperGradients team.</p>
  217. </div>
  218. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  219. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  220. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  221. </footer>
  222. </div>
  223. </div>
  224. </section>
  225. </div>
  226. <script>
  227. jQuery(function () {
  228. SphinxRtdTheme.Navigation.enable(true);
  229. });
  230. </script>
  231. </body>
  232. </html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.training.utils.early_stopping &mdash; SuperGradients 1.0 documentation</title>
  7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
  10. <!--[if lt IE 9]>
  11. <script src="../../../../_static/js/html5shiv.min.js"></script>
  12. <![endif]-->
  13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
  14. <script src="../../../../_static/jquery.js"></script>
  15. <script src="../../../../_static/underscore.js"></script>
  16. <script src="../../../../_static/doctools.js"></script>
  17. <script src="../../../../_static/js/theme.js"></script>
  18. <link rel="index" title="Index" href="../../../../genindex.html" />
  19. <link rel="search" title="Search" href="../../../../search.html" />
  20. </head>
  21. <body class="wy-body-for-nav">
  22. <div class="wy-grid-for-nav">
  23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  24. <div class="wy-side-scroll">
  25. <div class="wy-side-nav-search" >
  26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
  27. </a>
  28. <div role="search">
  29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
  30. <input type="text" name="q" placeholder="Search docs" />
  31. <input type="hidden" name="check_keywords" value="yes" />
  32. <input type="hidden" name="area" value="default" />
  33. </form>
  34. </div>
  35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
  37. <ul>
  38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
  39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
  40. </ul>
  41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
  42. <ul>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
  45. </ul>
  46. <p class="caption"><span class="caption-text">User Guide</span></p>
  47. <ul>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
  57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
  58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
  59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
  60. </ul>
  61. </div>
  62. </div>
  63. </nav>
  64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  66. <a href="../../../../index.html">SuperGradients</a>
  67. </nav>
  68. <div class="wy-nav-content">
  69. <div class="rst-content">
  70. <div role="navigation" aria-label="Page navigation">
  71. <ul class="wy-breadcrumbs">
  72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
  74. <li>super_gradients.training.utils.early_stopping</li>
  75. <li class="wy-breadcrumbs-aside">
  76. </li>
  77. </ul>
  78. <hr/>
  79. </div>
  80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  81. <div itemprop="articleBody">
  82. <h1>Source code for super_gradients.training.utils.early_stopping</h1><div class="highlight"><pre>
  83. <span></span><span class="kn">from</span> <span class="nn">super_gradients.training.utils.callbacks</span> <span class="kn">import</span> <span class="n">PhaseCallback</span><span class="p">,</span> <span class="n">Phase</span><span class="p">,</span> <span class="n">PhaseContext</span>
  84. <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Optional</span>
  85. <span class="kn">import</span> <span class="nn">torch</span>
  86. <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
  87. <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
  88. <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
  89. <div class="viewcode-block" id="EarlyStop"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.early_stopping.EarlyStop">[docs]</a><span class="k">class</span> <span class="nc">EarlyStop</span><span class="p">(</span><span class="n">PhaseCallback</span><span class="p">):</span>
  90. <span class="sd">&quot;&quot;&quot;</span>
  91. <span class="sd"> Callback to monitor a metric and stop training when it stops improving.</span>
  92. <span class="sd"> Inspired by pytorch_lightning.callbacks.early_stopping and tf.keras.callbacks.EarlyStopping</span>
  93. <span class="sd"> &quot;&quot;&quot;</span>
  94. <span class="n">mode_dict</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;min&quot;</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">lt</span><span class="p">,</span> <span class="s2">&quot;max&quot;</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">gt</span><span class="p">}</span>
  95. <span class="n">supported_phases</span> <span class="o">=</span> <span class="p">(</span><span class="n">Phase</span><span class="o">.</span><span class="n">VALIDATION_EPOCH_END</span><span class="p">,</span> <span class="n">Phase</span><span class="o">.</span><span class="n">TRAIN_EPOCH_END</span><span class="p">)</span>
  96. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
  97. <span class="n">phase</span><span class="p">:</span> <span class="n">Phase</span><span class="p">,</span>
  98. <span class="n">monitor</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
  99. <span class="n">mode</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;min&quot;</span><span class="p">,</span>
  100. <span class="n">min_delta</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
  101. <span class="n">patience</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">3</span><span class="p">,</span>
  102. <span class="n">check_finite</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
  103. <span class="n">threshold</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
  104. <span class="n">verbose</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
  105. <span class="n">strict</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span>
  106. <span class="p">):</span>
  107. <span class="sd">&quot;&quot;&quot;</span>
  108. <span class="sd"> :param phase: Callback phase event.</span>
  109. <span class="sd"> :param monitor: name of the metric to be monitored.</span>
  110. <span class="sd"> :param mode: one of &#39;min&#39;, &#39;max&#39;. In &#39;min&#39; mode, training will stop when the quantity</span>
  111. <span class="sd"> monitored has stopped decreasing and in &#39;max&#39; mode it will stop when the quantity</span>
  112. <span class="sd"> monitored has stopped increasing.</span>
  113. <span class="sd"> :param min_delta: minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute</span>
  114. <span class="sd"> change of less than `min_delta`, will count as no improvement.</span>
  115. <span class="sd"> :param patience: number of checks with no improvement after which training will be stopped.</span>
  116. <span class="sd"> One check happens after every phase event.</span>
  117. <span class="sd"> :param check_finite: When set ``True``, stops training when the monitor becomes NaN or infinite.</span>
  118. <span class="sd"> :param threshold: Stop training immediately once the monitored quantity reaches this threshold. For mode &#39;min&#39;</span>
  119. <span class="sd"> stops training when below threshold, For mode &#39;max&#39; stops training when above threshold.</span>
  120. <span class="sd"> :param verbose: If `True` print logs.</span>
  121. <span class="sd"> :param strict: whether to crash the training if `monitor` is not found in the metrics.</span>
  122. <span class="sd"> &quot;&quot;&quot;</span>
  123. <span class="nb">super</span><span class="p">(</span><span class="n">EarlyStop</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">phase</span><span class="p">)</span>
  124. <span class="k">if</span> <span class="n">phase</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">supported_phases</span><span class="p">:</span>
  125. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;EarlyStop doesn&#39;t support phase: </span><span class="si">{</span><span class="n">phase</span><span class="si">}</span><span class="s2">, &quot;</span>
  126. <span class="sa">f</span><span class="s2">&quot;excepted </span><span class="si">{</span><span class="s1">&#39;, &#39;</span><span class="o">.</span><span class="n">join</span><span class="p">([</span><span class="nb">str</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">supported_phases</span><span class="p">])</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
  127. <span class="bp">self</span><span class="o">.</span><span class="n">phase</span> <span class="o">=</span> <span class="n">phase</span>
  128. <span class="bp">self</span><span class="o">.</span><span class="n">monitor_key</span> <span class="o">=</span> <span class="n">monitor</span>
  129. <span class="bp">self</span><span class="o">.</span><span class="n">min_delta</span> <span class="o">=</span> <span class="n">min_delta</span>
  130. <span class="bp">self</span><span class="o">.</span><span class="n">patience</span> <span class="o">=</span> <span class="n">patience</span>
  131. <span class="bp">self</span><span class="o">.</span><span class="n">mode</span> <span class="o">=</span> <span class="n">mode</span>
  132. <span class="bp">self</span><span class="o">.</span><span class="n">check_finite</span> <span class="o">=</span> <span class="n">check_finite</span>
  133. <span class="bp">self</span><span class="o">.</span><span class="n">threshold</span> <span class="o">=</span> <span class="n">threshold</span>
  134. <span class="bp">self</span><span class="o">.</span><span class="n">verbose</span> <span class="o">=</span> <span class="n">verbose</span>
  135. <span class="bp">self</span><span class="o">.</span><span class="n">strict</span> <span class="o">=</span> <span class="n">strict</span>
  136. <span class="bp">self</span><span class="o">.</span><span class="n">wait_count</span> <span class="o">=</span> <span class="mi">0</span>
  137. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode_dict</span><span class="p">:</span>
  138. <span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;`mode` can be </span><span class="si">{</span><span class="s1">&#39;, &#39;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mode_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span><span class="si">}</span><span class="s2">, got </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
  139. <span class="bp">self</span><span class="o">.</span><span class="n">monitor_op</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode_dict</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="p">]</span>
  140. <span class="bp">self</span><span class="o">.</span><span class="n">min_delta</span> <span class="o">*=</span> <span class="mi">1</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">monitor_op</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">gt</span> <span class="k">else</span> <span class="o">-</span><span class="mi">1</span>
  141. <span class="n">torch_inf</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">Inf</span><span class="p">)</span>
  142. <span class="bp">self</span><span class="o">.</span><span class="n">best_score</span> <span class="o">=</span> <span class="n">torch_inf</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">monitor_op</span> <span class="o">==</span> <span class="n">torch</span><span class="o">.</span><span class="n">lt</span> <span class="k">else</span> <span class="o">-</span><span class="n">torch_inf</span>
  143. <span class="k">def</span> <span class="nf">_get_metric_value</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">metrics_dict</span><span class="p">):</span>
  144. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">monitor_key</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">metrics_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
  145. <span class="n">msg</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">&quot;Can&#39;t find EarlyStop monitor </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">monitor_key</span><span class="si">}</span><span class="s2"> in metrics_dict: </span><span class="si">{</span><span class="n">metrics_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span>
  146. <span class="n">exception_cls</span> <span class="o">=</span> <span class="ne">RuntimeError</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">strict</span> <span class="k">else</span> <span class="n">MissingMonitorKeyException</span>
  147. <span class="k">raise</span> <span class="n">exception_cls</span><span class="p">(</span><span class="n">msg</span><span class="p">)</span>
  148. <span class="k">return</span> <span class="n">metrics_dict</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">monitor_key</span><span class="p">]</span>
  149. <span class="k">def</span> <span class="nf">_check_for_early_stop</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">current</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
  150. <span class="n">should_stop</span> <span class="o">=</span> <span class="kc">False</span>
  151. <span class="c1"># check if current value is Nan or inf</span>
  152. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">check_finite</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">torch</span><span class="o">.</span><span class="n">isfinite</span><span class="p">(</span><span class="n">current</span><span class="p">):</span>
  153. <span class="n">should_stop</span> <span class="o">=</span> <span class="kc">True</span>
  154. <span class="n">reason</span> <span class="o">=</span> <span class="p">(</span>
  155. <span class="sa">f</span><span class="s2">&quot;Monitored metric </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">monitor_key</span><span class="si">}</span><span class="s2"> = </span><span class="si">{</span><span class="n">current</span><span class="si">}</span><span class="s2"> is not finite.&quot;</span>
  156. <span class="sa">f</span><span class="s2">&quot; Previous best value was </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">best_score</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2">. Signaling Trainer to stop.&quot;</span>
  157. <span class="p">)</span>
  158. <span class="c1"># check if current value reached threshold value</span>
  159. <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">threshold</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">monitor_op</span><span class="p">(</span><span class="n">current</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">threshold</span><span class="p">):</span>
  160. <span class="n">should_stop</span> <span class="o">=</span> <span class="kc">True</span>
  161. <span class="n">reason</span> <span class="o">=</span> <span class="p">(</span>
  162. <span class="s2">&quot;Stopping threshold reached:&quot;</span>
  163. <span class="sa">f</span><span class="s2">&quot; </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">monitor_key</span><span class="si">}</span><span class="s2"> = </span><span class="si">{</span><span class="n">current</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">monitor_op</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">threshold</span><span class="si">}</span><span class="s2">.&quot;</span>
  164. <span class="s2">&quot; Signaling Trainer to stop.&quot;</span>
  165. <span class="p">)</span>
  166. <span class="c1"># check if current is an improvement of monitor_key metric.</span>
  167. <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">monitor_op</span><span class="p">(</span><span class="n">current</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">min_delta</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">best_score</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">current</span><span class="o">.</span><span class="n">device</span><span class="p">)):</span>
  168. <span class="n">should_stop</span> <span class="o">=</span> <span class="kc">False</span>
  169. <span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">isfinite</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">best_score</span><span class="p">):</span>
  170. <span class="n">reason</span> <span class="o">=</span> <span class="p">(</span>
  171. <span class="sa">f</span><span class="s2">&quot;Metric </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">monitor_key</span><span class="si">}</span><span class="s2"> improved by </span><span class="si">{</span><span class="nb">abs</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">best_score</span> <span class="o">-</span> <span class="n">current</span><span class="p">)</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2"> &gt;=&quot;</span>
  172. <span class="sa">f</span><span class="s2">&quot; min_delta = </span><span class="si">{</span><span class="nb">abs</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">min_delta</span><span class="p">)</span><span class="si">}</span><span class="s2">. New best score: </span><span class="si">{</span><span class="n">current</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2">&quot;</span>
  173. <span class="p">)</span>
  174. <span class="k">else</span><span class="p">:</span>
  175. <span class="n">reason</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">&quot;Metric </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">monitor_key</span><span class="si">}</span><span class="s2"> improved. New best score: </span><span class="si">{</span><span class="n">current</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2">&quot;</span>
  176. <span class="bp">self</span><span class="o">.</span><span class="n">best_score</span> <span class="o">=</span> <span class="n">current</span>
  177. <span class="bp">self</span><span class="o">.</span><span class="n">wait_count</span> <span class="o">=</span> <span class="mi">0</span>
  178. <span class="c1"># no improvement in monitor_key metric, check if wait_count is bigger than patience.</span>
  179. <span class="k">else</span><span class="p">:</span>
  180. <span class="bp">self</span><span class="o">.</span><span class="n">wait_count</span> <span class="o">+=</span> <span class="mi">1</span>
  181. <span class="n">reason</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">&quot;Monitored metric </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">monitor_key</span><span class="si">}</span><span class="s2"> did not improve in the last </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">wait_count</span><span class="si">}</span><span class="s2"> records.&quot;</span>
  182. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">wait_count</span> <span class="o">&gt;=</span> <span class="bp">self</span><span class="o">.</span><span class="n">patience</span><span class="p">:</span>
  183. <span class="n">should_stop</span> <span class="o">=</span> <span class="kc">True</span>
  184. <span class="n">reason</span> <span class="o">+=</span> <span class="sa">f</span><span class="s2">&quot; Best score: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">best_score</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2">. Signaling Trainer to stop.&quot;</span>
  185. <span class="k">return</span> <span class="n">reason</span><span class="p">,</span> <span class="n">should_stop</span>
  186. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">PhaseContext</span><span class="p">):</span>
  187. <span class="k">try</span><span class="p">:</span>
  188. <span class="n">current</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_metric_value</span><span class="p">(</span><span class="n">context</span><span class="o">.</span><span class="n">metrics_dict</span><span class="p">)</span>
  189. <span class="k">except</span> <span class="n">MissingMonitorKeyException</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
  190. <span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="n">e</span><span class="p">)</span>
  191. <span class="k">return</span>
  192. <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">current</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
  193. <span class="n">current</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">current</span><span class="p">)</span>
  194. <span class="n">reason</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">should_stop</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_check_for_early_stop</span><span class="p">(</span><span class="n">current</span><span class="p">)</span>
  195. <span class="c1"># log reason message, and signal early stop if should_stop=True.</span>
  196. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">should_stop</span><span class="p">:</span>
  197. <span class="bp">self</span><span class="o">.</span><span class="n">_signal_early_stop</span><span class="p">(</span><span class="n">context</span><span class="p">,</span> <span class="n">reason</span><span class="p">)</span>
  198. <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">verbose</span><span class="p">:</span>
  199. <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="n">reason</span><span class="p">)</span>
  200. <span class="k">def</span> <span class="nf">_signal_early_stop</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">context</span><span class="p">:</span> <span class="n">PhaseContext</span><span class="p">,</span> <span class="n">reason</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
  201. <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="n">reason</span><span class="p">)</span>
  202. <span class="n">context</span><span class="o">.</span><span class="n">update_context</span><span class="p">(</span><span class="n">stop_training</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></div>
  203. <div class="viewcode-block" id="MissingMonitorKeyException"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.early_stopping.MissingMonitorKeyException">[docs]</a><span class="k">class</span> <span class="nc">MissingMonitorKeyException</span><span class="p">(</span><span class="ne">Exception</span><span class="p">):</span>
  204. <span class="sd">&quot;&quot;&quot;</span>
  205. <span class="sd"> Exception raised for missing monitor key in metrics_dict.</span>
  206. <span class="sd"> &quot;&quot;&quot;</span>
  207. <span class="k">pass</span></div>
  208. </pre></div>
  209. </div>
  210. </div>
  211. <footer>
  212. <hr/>
  213. <div role="contentinfo">
  214. <p>&#169; Copyright 2021, SuperGradients team.</p>
  215. </div>
  216. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  217. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  218. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  219. </footer>
  220. </div>
  221. </div>
  222. </section>
  223. </div>
  224. <script>
  225. jQuery(function () {
  226. SphinxRtdTheme.Navigation.enable(true);
  227. });
  228. </script>
  229. </body>
  230. </html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.training.utils.ema &mdash; SuperGradients 1.0 documentation</title>
  7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
  10. <!--[if lt IE 9]>
  11. <script src="../../../../_static/js/html5shiv.min.js"></script>
  12. <![endif]-->
  13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
  14. <script src="../../../../_static/jquery.js"></script>
  15. <script src="../../../../_static/underscore.js"></script>
  16. <script src="../../../../_static/doctools.js"></script>
  17. <script src="../../../../_static/js/theme.js"></script>
  18. <link rel="index" title="Index" href="../../../../genindex.html" />
  19. <link rel="search" title="Search" href="../../../../search.html" />
  20. </head>
  21. <body class="wy-body-for-nav">
  22. <div class="wy-grid-for-nav">
  23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  24. <div class="wy-side-scroll">
  25. <div class="wy-side-nav-search" >
  26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
  27. </a>
  28. <div role="search">
  29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
  30. <input type="text" name="q" placeholder="Search docs" />
  31. <input type="hidden" name="check_keywords" value="yes" />
  32. <input type="hidden" name="area" value="default" />
  33. </form>
  34. </div>
  35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
  37. <ul>
  38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
  39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
  40. </ul>
  41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
  42. <ul>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
  45. </ul>
  46. <p class="caption"><span class="caption-text">User Guide</span></p>
  47. <ul>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
  57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
  58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
  59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
  60. </ul>
  61. </div>
  62. </div>
  63. </nav>
  64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  66. <a href="../../../../index.html">SuperGradients</a>
  67. </nav>
  68. <div class="wy-nav-content">
  69. <div class="rst-content">
  70. <div role="navigation" aria-label="Page navigation">
  71. <ul class="wy-breadcrumbs">
  72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
  74. <li>super_gradients.training.utils.ema</li>
  75. <li class="wy-breadcrumbs-aside">
  76. </li>
  77. </ul>
  78. <hr/>
  79. </div>
  80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  81. <div itemprop="articleBody">
  82. <h1>Source code for super_gradients.training.utils.ema</h1><div class="highlight"><pre>
  83. <span></span><span class="kn">import</span> <span class="nn">math</span>
  84. <span class="kn">import</span> <span class="nn">warnings</span>
  85. <span class="kn">from</span> <span class="nn">copy</span> <span class="kn">import</span> <span class="n">deepcopy</span>
  86. <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Union</span>
  87. <span class="kn">import</span> <span class="nn">torch</span>
  88. <span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
  89. <span class="kn">from</span> <span class="nn">super_gradients.training</span> <span class="kn">import</span> <span class="n">utils</span> <span class="k">as</span> <span class="n">core_utils</span>
  90. <span class="kn">from</span> <span class="nn">super_gradients.training.models</span> <span class="kn">import</span> <span class="n">SgModule</span>
  91. <span class="kn">from</span> <span class="nn">super_gradients.training.models.kd_modules.kd_module</span> <span class="kn">import</span> <span class="n">KDModule</span>
  92. <div class="viewcode-block" id="copy_attr"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.ema.copy_attr">[docs]</a><span class="k">def</span> <span class="nf">copy_attr</span><span class="p">(</span><span class="n">a</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">b</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">include</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">]</span> <span class="o">=</span> <span class="p">(),</span> <span class="n">exclude</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">list</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">]</span> <span class="o">=</span> <span class="p">()):</span>
  93. <span class="c1"># Copy attributes from b to a, options to only include [...] and to exclude [...]</span>
  94. <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">b</span><span class="o">.</span><span class="vm">__dict__</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
  95. <span class="k">if</span> <span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">include</span><span class="p">)</span> <span class="ow">and</span> <span class="n">k</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">include</span><span class="p">)</span> <span class="ow">or</span> <span class="n">k</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s1">&#39;_&#39;</span><span class="p">)</span> <span class="ow">or</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">exclude</span><span class="p">:</span>
  96. <span class="k">continue</span>
  97. <span class="k">else</span><span class="p">:</span>
  98. <span class="nb">setattr</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span></div>
  99. <div class="viewcode-block" id="ModelEMA"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.ema.ModelEMA">[docs]</a><span class="k">class</span> <span class="nc">ModelEMA</span><span class="p">:</span>
  100. <span class="sd">&quot;&quot;&quot; Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models</span>
  101. <span class="sd"> Keep a moving average of everything in the model state_dict (parameters and buffers).</span>
  102. <span class="sd"> This is intended to allow functionality like</span>
  103. <span class="sd"> https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage</span>
  104. <span class="sd"> A smoothed version of the weights is necessary for some training schemes to perform well.</span>
  105. <span class="sd"> This class is sensitive where it is initialized in the sequence of model init,</span>
  106. <span class="sd"> GPU assignment and distributed training wrappers.</span>
  107. <span class="sd"> &quot;&quot;&quot;</span>
  108. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">decay</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.9999</span><span class="p">,</span> <span class="n">beta</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mi">15</span><span class="p">,</span> <span class="n">exp_activation</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span>
  109. <span class="sd">&quot;&quot;&quot;</span>
  110. <span class="sd"> Init the EMA</span>
  111. <span class="sd"> :param model: Union[SgModule, nn.Module], the training model to construct the EMA model by</span>
  112. <span class="sd"> IMPORTANT: WHEN THE APPLICATION OF EMA ONLY ON A SUBSET OF ATTRIBUTES IS DESIRED, WRAP THE NN.MODULE</span>
  113. <span class="sd"> AS SgModule AND OVERWRITE get_include_attributes() AND get_exclude_attributes() AS DESIRED.</span>
  114. <span class="sd"> :param decay: the maximum decay value. as the training process advances, the decay will climb towards this value</span>
  115. <span class="sd"> until the EMA_t+1 = EMA_t * decay + TRAINING_MODEL * (1- decay)</span>
  116. <span class="sd"> :param beta: the exponent coefficient. The higher the beta, the sooner in the training the decay will saturate to</span>
  117. <span class="sd"> its final value. beta=15 is ~40% of the training process.</span>
  118. <span class="sd"> &quot;&quot;&quot;</span>
  119. <span class="c1"># Create EMA</span>
  120. <span class="bp">self</span><span class="o">.</span><span class="n">ema</span> <span class="o">=</span> <span class="n">deepcopy</span><span class="p">(</span><span class="n">model</span><span class="p">)</span>
  121. <span class="bp">self</span><span class="o">.</span><span class="n">ema</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span>
  122. <span class="k">if</span> <span class="n">exp_activation</span><span class="p">:</span>
  123. <span class="bp">self</span><span class="o">.</span><span class="n">decay_function</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">decay</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">math</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="o">-</span><span class="n">x</span> <span class="o">*</span> <span class="n">beta</span><span class="p">))</span> <span class="c1"># decay exponential ramp (to help early epochs)</span>
  124. <span class="k">else</span><span class="p">:</span>
  125. <span class="bp">self</span><span class="o">.</span><span class="n">decay_function</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">decay</span> <span class="c1"># always return the same decay factor</span>
  126. <span class="sd">&quot;&quot;&quot;&quot;</span>
  127. <span class="sd"> we hold a list of model attributes (not wights and biases) which we would like to include in each</span>
  128. <span class="sd"> attribute update or exclude from each update. a SgModule declare these attribute using</span>
  129. <span class="sd"> get_include_attributes and get_exclude_attributes functions. for a nn.Module which is not a SgModule</span>
  130. <span class="sd"> all non-private (not starting with &#39;_&#39;) attributes will be updated (and only them).</span>
  131. <span class="sd"> &quot;&quot;&quot;</span>
  132. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">module</span><span class="p">,</span> <span class="n">SgModule</span><span class="p">):</span>
  133. <span class="bp">self</span><span class="o">.</span><span class="n">include_attributes</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">module</span><span class="o">.</span><span class="n">get_include_attributes</span><span class="p">()</span>
  134. <span class="bp">self</span><span class="o">.</span><span class="n">exclude_attributes</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">module</span><span class="o">.</span><span class="n">get_exclude_attributes</span><span class="p">()</span>
  135. <span class="k">else</span><span class="p">:</span>
  136. <span class="n">warnings</span><span class="o">.</span><span class="n">warn</span><span class="p">(</span><span class="s2">&quot;Warning: EMA should be used with SgModule instance. All attributes of the model will be &quot;</span>
  137. <span class="s2">&quot;included in EMA&quot;</span><span class="p">)</span>
  138. <span class="bp">self</span><span class="o">.</span><span class="n">include_attributes</span> <span class="o">=</span> <span class="p">[]</span>
  139. <span class="bp">self</span><span class="o">.</span><span class="n">exclude_attributes</span> <span class="o">=</span> <span class="p">[]</span>
  140. <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">ema</span><span class="o">.</span><span class="n">module</span><span class="o">.</span><span class="n">parameters</span><span class="p">():</span>
  141. <span class="n">p</span><span class="o">.</span><span class="n">requires_grad_</span><span class="p">(</span><span class="kc">False</span><span class="p">)</span>
  142. <div class="viewcode-block" id="ModelEMA.update"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.ema.ModelEMA.update">[docs]</a> <span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">training_percent</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span>
  143. <span class="sd">&quot;&quot;&quot;</span>
  144. <span class="sd"> Update the state of the EMA model.</span>
  145. <span class="sd"> :param model: current training model</span>
  146. <span class="sd"> :param training_percent: the percentage of the training process [0,1]. i.e 0.4 means 40% of the training have passed</span>
  147. <span class="sd"> &quot;&quot;&quot;</span>
  148. <span class="c1"># Update EMA parameters</span>
  149. <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
  150. <span class="n">decay</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">decay_function</span><span class="p">(</span><span class="n">training_percent</span><span class="p">)</span>
  151. <span class="k">for</span> <span class="n">ema_v</span><span class="p">,</span> <span class="n">model_v</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ema</span><span class="o">.</span><span class="n">module</span><span class="o">.</span><span class="n">state_dict</span><span class="p">()</span><span class="o">.</span><span class="n">values</span><span class="p">(),</span> <span class="n">model</span><span class="o">.</span><span class="n">state_dict</span><span class="p">()</span><span class="o">.</span><span class="n">values</span><span class="p">()):</span>
  152. <span class="k">if</span> <span class="n">ema_v</span><span class="o">.</span><span class="n">dtype</span><span class="o">.</span><span class="n">is_floating_point</span><span class="p">:</span>
  153. <span class="n">ema_v</span><span class="o">.</span><span class="n">copy_</span><span class="p">(</span><span class="n">ema_v</span> <span class="o">*</span> <span class="n">decay</span> <span class="o">+</span> <span class="p">(</span><span class="mf">1.</span> <span class="o">-</span> <span class="n">decay</span><span class="p">)</span> <span class="o">*</span> <span class="n">model_v</span><span class="o">.</span><span class="n">detach</span><span class="p">())</span></div>
  154. <div class="viewcode-block" id="ModelEMA.update_attr"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.ema.ModelEMA.update_attr">[docs]</a> <span class="k">def</span> <span class="nf">update_attr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model</span><span class="p">):</span>
  155. <span class="sd">&quot;&quot;&quot;</span>
  156. <span class="sd"> This function updates model attributes (not weight and biases) from original model to the ema model.</span>
  157. <span class="sd"> attributes of the original model, such as anchors and grids (of detection models), may be crucial to the</span>
  158. <span class="sd"> model operation and need to be updated.</span>
  159. <span class="sd"> If include_attributes and exclude_attributes lists were not defined, all non-private (not starting with &#39;_&#39;)</span>
  160. <span class="sd"> attributes will be updated (and only them).</span>
  161. <span class="sd"> :param model: the source model</span>
  162. <span class="sd"> &quot;&quot;&quot;</span>
  163. <span class="n">copy_attr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ema</span><span class="o">.</span><span class="n">module</span><span class="p">,</span> <span class="n">model</span><span class="o">.</span><span class="n">module</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">include_attributes</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">exclude_attributes</span><span class="p">)</span></div></div>
  164. <div class="viewcode-block" id="KDModelEMA"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.ema.KDModelEMA">[docs]</a><span class="k">class</span> <span class="nc">KDModelEMA</span><span class="p">(</span><span class="n">ModelEMA</span><span class="p">):</span>
  165. <span class="sd">&quot;&quot;&quot; Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models</span>
  166. <span class="sd"> Keep a moving average of everything in the model state_dict (parameters and buffers).</span>
  167. <span class="sd"> This is intended to allow functionality like</span>
  168. <span class="sd"> https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage</span>
  169. <span class="sd"> A smoothed version of the weights is necessary for some training schemes to perform well.</span>
  170. <span class="sd"> This class is sensitive where it is initialized in the sequence of model init,</span>
  171. <span class="sd"> GPU assignment and distributed training wrappers.</span>
  172. <span class="sd"> &quot;&quot;&quot;</span>
  173. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">kd_model</span><span class="p">:</span> <span class="n">KDModule</span><span class="p">,</span> <span class="n">decay</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.9999</span><span class="p">,</span> <span class="n">beta</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mi">15</span><span class="p">,</span> <span class="n">exp_activation</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span>
  174. <span class="sd">&quot;&quot;&quot;</span>
  175. <span class="sd"> Init the EMA</span>
  176. <span class="sd"> :param kd_model: KDModule, the training Knowledge distillation model to construct the EMA model by</span>
  177. <span class="sd"> IMPORTANT: WHEN THE APPLICATION OF EMA ONLY ON A SUBSET OF ATTRIBUTES IS DESIRED, WRAP THE NN.MODULE</span>
  178. <span class="sd"> AS SgModule AND OVERWRITE get_include_attributes() AND get_exclude_attributes() AS DESIRED.</span>
  179. <span class="sd"> :param decay: the maximum decay value. as the training process advances, the decay will climb towards this value</span>
  180. <span class="sd"> until the EMA_t+1 = EMA_t * decay + TRAINING_MODEL * (1- decay)</span>
  181. <span class="sd"> :param beta: the exponent coefficient. The higher the beta, the sooner in the training the decay will saturate to</span>
  182. <span class="sd"> its final value. beta=15 is ~40% of the training process.</span>
  183. <span class="sd"> &quot;&quot;&quot;</span>
  184. <span class="c1"># Only work on the student (we don&#39;t want to update and to have a duplicate of the teacher)</span>
  185. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">model</span><span class="o">=</span><span class="n">core_utils</span><span class="o">.</span><span class="n">WrappedModel</span><span class="p">(</span><span class="n">kd_model</span><span class="o">.</span><span class="n">module</span><span class="o">.</span><span class="n">student</span><span class="p">),</span>
  186. <span class="n">decay</span><span class="o">=</span><span class="n">decay</span><span class="p">,</span>
  187. <span class="n">beta</span><span class="o">=</span><span class="n">beta</span><span class="p">,</span>
  188. <span class="n">exp_activation</span><span class="o">=</span><span class="n">exp_activation</span><span class="p">)</span>
  189. <span class="c1"># Overwrite current ema attribute with combination of the student model EMA (current self.ema)</span>
  190. <span class="c1"># with already the instantiated teacher, to have the final KD EMA</span>
  191. <span class="bp">self</span><span class="o">.</span><span class="n">ema</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">WrappedModel</span><span class="p">(</span><span class="n">KDModule</span><span class="p">(</span><span class="n">arch_params</span><span class="o">=</span><span class="n">kd_model</span><span class="o">.</span><span class="n">module</span><span class="o">.</span><span class="n">arch_params</span><span class="p">,</span>
  192. <span class="n">student</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">ema</span><span class="o">.</span><span class="n">module</span><span class="p">,</span>
  193. <span class="n">teacher</span><span class="o">=</span><span class="n">kd_model</span><span class="o">.</span><span class="n">module</span><span class="o">.</span><span class="n">teacher</span><span class="p">,</span>
  194. <span class="n">run_teacher_on_eval</span><span class="o">=</span><span class="n">kd_model</span><span class="o">.</span><span class="n">module</span><span class="o">.</span><span class="n">run_teacher_on_eval</span><span class="p">))</span></div>
  195. </pre></div>
  196. </div>
  197. </div>
  198. <footer>
  199. <hr/>
  200. <div role="contentinfo">
  201. <p>&#169; Copyright 2021, SuperGradients team.</p>
  202. </div>
  203. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  204. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  205. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  206. </footer>
  207. </div>
  208. </div>
  209. </section>
  210. </div>
  211. <script>
  212. jQuery(function () {
  213. SphinxRtdTheme.Navigation.enable(true);
  214. });
  215. </script>
  216. </body>
  217. </html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.training.utils.export_utils &mdash; SuperGradients 1.0 documentation</title>
  7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
  10. <!--[if lt IE 9]>
  11. <script src="../../../../_static/js/html5shiv.min.js"></script>
  12. <![endif]-->
  13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
  14. <script src="../../../../_static/jquery.js"></script>
  15. <script src="../../../../_static/underscore.js"></script>
  16. <script src="../../../../_static/doctools.js"></script>
  17. <script src="../../../../_static/js/theme.js"></script>
  18. <link rel="index" title="Index" href="../../../../genindex.html" />
  19. <link rel="search" title="Search" href="../../../../search.html" />
  20. </head>
  21. <body class="wy-body-for-nav">
  22. <div class="wy-grid-for-nav">
  23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  24. <div class="wy-side-scroll">
  25. <div class="wy-side-nav-search" >
  26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
  27. </a>
  28. <div role="search">
  29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
  30. <input type="text" name="q" placeholder="Search docs" />
  31. <input type="hidden" name="check_keywords" value="yes" />
  32. <input type="hidden" name="area" value="default" />
  33. </form>
  34. </div>
  35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
  37. <ul>
  38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
  39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
  40. </ul>
  41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
  42. <ul>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
  45. </ul>
  46. <p class="caption"><span class="caption-text">User Guide</span></p>
  47. <ul>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
  57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
  58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
  59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
  60. </ul>
  61. </div>
  62. </div>
  63. </nav>
  64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  66. <a href="../../../../index.html">SuperGradients</a>
  67. </nav>
  68. <div class="wy-nav-content">
  69. <div class="rst-content">
  70. <div role="navigation" aria-label="Page navigation">
  71. <ul class="wy-breadcrumbs">
  72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
  74. <li>super_gradients.training.utils.export_utils</li>
  75. <li class="wy-breadcrumbs-aside">
  76. </li>
  77. </ul>
  78. <hr/>
  79. </div>
  80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  81. <div itemprop="articleBody">
  82. <h1>Source code for super_gradients.training.utils.export_utils</h1><div class="highlight"><pre>
  83. <span></span><span class="kn">import</span> <span class="nn">torch</span>
  84. <span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
  85. <div class="viewcode-block" id="fuse_conv_bn"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.export_utils.fuse_conv_bn">[docs]</a><span class="k">def</span> <span class="nf">fuse_conv_bn</span><span class="p">(</span><span class="n">model</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">replace_bn_with_identity</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
  86. <span class="sd">&quot;&quot;&quot;</span>
  87. <span class="sd"> Fuses consecutive nn.Conv2d and nn.BatchNorm2d layers recursively inplace in all of the model</span>
  88. <span class="sd"> :param replace_bn_with_identity: if set to true, bn will be replaced with identity. otherwise, bn will be removed</span>
  89. <span class="sd"> :param model: the target model</span>
  90. <span class="sd"> :return: the number of fuses executed</span>
  91. <span class="sd"> &quot;&quot;&quot;</span>
  92. <span class="n">children</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">named_children</span><span class="p">())</span>
  93. <span class="n">counter</span> <span class="o">=</span> <span class="mi">0</span>
  94. <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">children</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">):</span>
  95. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">children</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">1</span><span class="p">],</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">children</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">][</span><span class="mi">1</span><span class="p">],</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm2d</span><span class="p">):</span>
  96. <span class="nb">setattr</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">children</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">0</span><span class="p">],</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">fuse_conv_bn_eval</span><span class="p">(</span><span class="n">children</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">1</span><span class="p">],</span> <span class="n">children</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">][</span><span class="mi">1</span><span class="p">]))</span>
  97. <span class="k">if</span> <span class="n">replace_bn_with_identity</span><span class="p">:</span>
  98. <span class="nb">setattr</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">children</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">][</span><span class="mi">0</span><span class="p">],</span> <span class="n">nn</span><span class="o">.</span><span class="n">Identity</span><span class="p">())</span>
  99. <span class="k">else</span><span class="p">:</span>
  100. <span class="nb">delattr</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">children</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">][</span><span class="mi">0</span><span class="p">])</span>
  101. <span class="n">counter</span> <span class="o">+=</span> <span class="mi">1</span>
  102. <span class="k">for</span> <span class="n">child_name</span><span class="p">,</span> <span class="n">child</span> <span class="ow">in</span> <span class="n">children</span><span class="p">:</span>
  103. <span class="n">counter</span> <span class="o">+=</span> <span class="n">fuse_conv_bn</span><span class="p">(</span><span class="n">child</span><span class="p">,</span> <span class="n">replace_bn_with_identity</span><span class="p">)</span>
  104. <span class="k">return</span> <span class="n">counter</span></div>
  105. </pre></div>
  106. </div>
  107. </div>
  108. <footer>
  109. <hr/>
  110. <div role="contentinfo">
  111. <p>&#169; Copyright 2021, SuperGradients team.</p>
  112. </div>
  113. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  114. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  115. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  116. </footer>
  117. </div>
  118. </div>
  119. </section>
  120. </div>
  121. <script>
  122. jQuery(function () {
  123. SphinxRtdTheme.Navigation.enable(true);
  124. });
  125. </script>
  126. </body>
  127. </html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.training.utils.module_utils &mdash; SuperGradients 1.0 documentation</title>
  7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
  10. <!--[if lt IE 9]>
  11. <script src="../../../../_static/js/html5shiv.min.js"></script>
  12. <![endif]-->
  13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
  14. <script src="../../../../_static/jquery.js"></script>
  15. <script src="../../../../_static/underscore.js"></script>
  16. <script src="../../../../_static/doctools.js"></script>
  17. <script src="../../../../_static/js/theme.js"></script>
  18. <link rel="index" title="Index" href="../../../../genindex.html" />
  19. <link rel="search" title="Search" href="../../../../search.html" />
  20. </head>
  21. <body class="wy-body-for-nav">
  22. <div class="wy-grid-for-nav">
  23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  24. <div class="wy-side-scroll">
  25. <div class="wy-side-nav-search" >
  26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
  27. </a>
  28. <div role="search">
  29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
  30. <input type="text" name="q" placeholder="Search docs" />
  31. <input type="hidden" name="check_keywords" value="yes" />
  32. <input type="hidden" name="area" value="default" />
  33. </form>
  34. </div>
  35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
  37. <ul>
  38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
  39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
  40. </ul>
  41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
  42. <ul>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
  45. </ul>
  46. <p class="caption"><span class="caption-text">User Guide</span></p>
  47. <ul>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
  57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
  58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
  59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
  60. </ul>
  61. </div>
  62. </div>
  63. </nav>
  64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  66. <a href="../../../../index.html">SuperGradients</a>
  67. </nav>
  68. <div class="wy-nav-content">
  69. <div class="rst-content">
  70. <div role="navigation" aria-label="Page navigation">
  71. <ul class="wy-breadcrumbs">
  72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
  74. <li>super_gradients.training.utils.module_utils</li>
  75. <li class="wy-breadcrumbs-aside">
  76. </li>
  77. </ul>
  78. <hr/>
  79. </div>
  80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  81. <div itemprop="articleBody">
  82. <h1>Source code for super_gradients.training.utils.module_utils</h1><div class="highlight"><pre>
  83. <span></span><span class="kn">from</span> <span class="nn">collections</span> <span class="kn">import</span> <span class="n">OrderedDict</span>
  84. <span class="kn">import</span> <span class="nn">copy</span>
  85. <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">Tuple</span>
  86. <span class="kn">import</span> <span class="nn">torch</span>
  87. <span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
  88. <div class="viewcode-block" id="MultiOutputModule"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.module_utils.MultiOutputModule">[docs]</a><span class="k">class</span> <span class="nc">MultiOutputModule</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
  89. <span class="sd">&quot;&quot;&quot;</span>
  90. <span class="sd"> This module wraps around a container nn.Module (such as Module, Sequential and ModuleList) and allows to extract</span>
  91. <span class="sd"> multiple output from its inner modules on each forward call() (as a list of output tensors)</span>
  92. <span class="sd"> note: the default output of the wrapped module will not be added to the output list by default. To get</span>
  93. <span class="sd"> the default output in the outputs list, explicitly include its path in the @output_paths parameter</span>
  94. <span class="sd"> i.e.</span>
  95. <span class="sd"> for module:</span>
  96. <span class="sd"> Sequential(</span>
  97. <span class="sd"> (0): Sequential(</span>
  98. <span class="sd"> (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)</span>
  99. <span class="sd"> (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)</span>
  100. <span class="sd"> (2): ReLU6(inplace=True)</span>
  101. <span class="sd"> ) ===================================&gt;&gt;</span>
  102. <span class="sd"> (1): InvertedResidual(</span>
  103. <span class="sd"> (conv): Sequential(</span>
  104. <span class="sd"> (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)</span>
  105. <span class="sd"> (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)</span>
  106. <span class="sd"> (2): ReLU6(inplace=True) ===================================&gt;&gt;</span>
  107. <span class="sd"> (3): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)</span>
  108. <span class="sd"> (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)</span>
  109. <span class="sd"> )</span>
  110. <span class="sd"> )</span>
  111. <span class="sd"> )</span>
  112. <span class="sd"> and paths:</span>
  113. <span class="sd"> [0, [1, &#39;conv&#39;, 2]]</span>
  114. <span class="sd"> the output are marked with arrows</span>
  115. <span class="sd"> &quot;&quot;&quot;</span>
  116. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">module</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">output_paths</span><span class="p">:</span> <span class="nb">list</span><span class="p">,</span> <span class="n">prune</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span>
  117. <span class="sd">&quot;&quot;&quot;</span>
  118. <span class="sd"> :param module: The wrapped container module</span>
  119. <span class="sd"> :param output_paths: a list of lists or keys containing the canonical paths to the outputs</span>
  120. <span class="sd"> i.e. [3, [4, &#39;conv&#39;, 5], 7] will extract outputs of layers 3, 7 and 4-&gt;conv-&gt;5</span>
  121. <span class="sd"> &quot;&quot;&quot;</span>
  122. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
  123. <span class="bp">self</span><span class="o">.</span><span class="n">output_paths</span> <span class="o">=</span> <span class="n">output_paths</span>
  124. <span class="bp">self</span><span class="o">.</span><span class="n">_modules</span><span class="p">[</span><span class="s1">&#39;0&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">module</span>
  125. <span class="bp">self</span><span class="o">.</span><span class="n">_outputs_lists</span> <span class="o">=</span> <span class="p">{}</span>
  126. <span class="k">for</span> <span class="n">path</span> <span class="ow">in</span> <span class="n">output_paths</span><span class="p">:</span>
  127. <span class="n">child</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_recursive</span><span class="p">(</span><span class="n">module</span><span class="p">,</span> <span class="n">path</span><span class="p">)</span>
  128. <span class="n">child</span><span class="o">.</span><span class="n">register_forward_hook</span><span class="p">(</span><span class="n">hook</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">save_output_hook</span><span class="p">)</span>
  129. <span class="c1"># PRUNE THE MODULE TO SUPPORT ALL PROVIDED OUTPUT_PATHS BUT REMOVE ALL REDUNDANT LAYERS</span>
  130. <span class="k">if</span> <span class="n">prune</span><span class="p">:</span>
  131. <span class="bp">self</span><span class="o">.</span><span class="n">_prune</span><span class="p">(</span><span class="n">module</span><span class="p">,</span> <span class="n">output_paths</span><span class="p">)</span>
  132. <div class="viewcode-block" id="MultiOutputModule.save_output_hook"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.module_utils.MultiOutputModule.save_output_hook">[docs]</a> <span class="k">def</span> <span class="nf">save_output_hook</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="nb">input</span><span class="p">,</span> <span class="n">output</span><span class="p">):</span>
  133. <span class="bp">self</span><span class="o">.</span><span class="n">_outputs_lists</span><span class="p">[</span><span class="nb">input</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">device</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">output</span><span class="p">)</span></div>
  134. <div class="viewcode-block" id="MultiOutputModule.forward"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.module_utils.MultiOutputModule.forward">[docs]</a> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">list</span><span class="p">:</span>
  135. <span class="bp">self</span><span class="o">.</span><span class="n">_outputs_lists</span><span class="p">[</span><span class="n">x</span><span class="o">.</span><span class="n">device</span><span class="p">]</span> <span class="o">=</span> <span class="p">[]</span>
  136. <span class="bp">self</span><span class="o">.</span><span class="n">_modules</span><span class="p">[</span><span class="s1">&#39;0&#39;</span><span class="p">](</span><span class="n">x</span><span class="p">)</span>
  137. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_outputs_lists</span><span class="p">[</span><span class="n">x</span><span class="o">.</span><span class="n">device</span><span class="p">]</span></div>
  138. <span class="k">def</span> <span class="nf">_get_recursive</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">module</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">path</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">:</span>
  139. <span class="sd">&quot;&quot;&quot;recursively look for a module using a path&quot;&quot;&quot;</span>
  140. <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
  141. <span class="k">return</span> <span class="n">module</span><span class="o">.</span><span class="n">_modules</span><span class="p">[</span><span class="nb">str</span><span class="p">(</span><span class="n">path</span><span class="p">)]</span>
  142. <span class="k">elif</span> <span class="nb">len</span><span class="p">(</span><span class="n">path</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
  143. <span class="k">return</span> <span class="n">module</span><span class="o">.</span><span class="n">_modules</span><span class="p">[</span><span class="nb">str</span><span class="p">(</span><span class="n">path</span><span class="p">[</span><span class="mi">0</span><span class="p">])]</span>
  144. <span class="k">else</span><span class="p">:</span>
  145. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_recursive</span><span class="p">(</span><span class="n">module</span><span class="o">.</span><span class="n">_modules</span><span class="p">[</span><span class="nb">str</span><span class="p">(</span><span class="n">path</span><span class="p">[</span><span class="mi">0</span><span class="p">])],</span> <span class="n">path</span><span class="p">[</span><span class="mi">1</span><span class="p">:])</span>
  146. <span class="k">def</span> <span class="nf">_prune</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">module</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">output_paths</span><span class="p">:</span> <span class="nb">list</span><span class="p">):</span>
  147. <span class="sd">&quot;&quot;&quot;</span>
  148. <span class="sd"> Recursively prune the module to support all provided output_paths but remove all redundant layers</span>
  149. <span class="sd"> &quot;&quot;&quot;</span>
  150. <span class="n">last_index</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span>
  151. <span class="n">last_key</span> <span class="o">=</span> <span class="kc">None</span>
  152. <span class="c1"># look for the last key from all paths</span>
  153. <span class="k">for</span> <span class="n">path</span> <span class="ow">in</span> <span class="n">output_paths</span><span class="p">:</span>
  154. <span class="n">key</span> <span class="o">=</span> <span class="n">path</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> <span class="k">else</span> <span class="n">path</span>
  155. <span class="n">index</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">module</span><span class="o">.</span><span class="n">_modules</span><span class="p">)</span><span class="o">.</span><span class="n">index</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">key</span><span class="p">))</span>
  156. <span class="k">if</span> <span class="n">index</span> <span class="o">&gt;</span> <span class="n">last_index</span><span class="p">:</span>
  157. <span class="n">last_index</span> <span class="o">=</span> <span class="n">index</span>
  158. <span class="n">last_key</span> <span class="o">=</span> <span class="n">key</span>
  159. <span class="n">module</span><span class="o">.</span><span class="n">_modules</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_slice_odict</span><span class="p">(</span><span class="n">module</span><span class="o">.</span><span class="n">_modules</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">last_index</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
  160. <span class="n">next_level_paths</span> <span class="o">=</span> <span class="p">[]</span>
  161. <span class="k">for</span> <span class="n">path</span> <span class="ow">in</span> <span class="n">output_paths</span><span class="p">:</span>
  162. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> <span class="ow">and</span> <span class="n">path</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">last_key</span> <span class="ow">and</span> <span class="nb">len</span><span class="p">(</span><span class="n">path</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
  163. <span class="n">next_level_paths</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">path</span><span class="p">[</span><span class="mi">1</span><span class="p">:])</span>
  164. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">next_level_paths</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
  165. <span class="bp">self</span><span class="o">.</span><span class="n">_prune</span><span class="p">(</span><span class="n">module</span><span class="o">.</span><span class="n">_modules</span><span class="p">[</span><span class="nb">str</span><span class="p">(</span><span class="n">last_key</span><span class="p">)],</span> <span class="n">next_level_paths</span><span class="p">)</span>
  166. <span class="k">def</span> <span class="nf">_slice_odict</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">odict</span><span class="p">:</span> <span class="n">OrderedDict</span><span class="p">,</span> <span class="n">start</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">end</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
  167. <span class="sd">&quot;&quot;&quot;Slice an OrderedDict in the same logic list,tuple... are sliced&quot;&quot;&quot;</span>
  168. <span class="k">return</span> <span class="n">OrderedDict</span><span class="p">([</span>
  169. <span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span> <span class="k">for</span> <span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span> <span class="ow">in</span> <span class="n">odict</span><span class="o">.</span><span class="n">items</span><span class="p">()</span>
  170. <span class="k">if</span> <span class="n">k</span> <span class="ow">in</span> <span class="nb">list</span><span class="p">(</span><span class="n">odict</span><span class="o">.</span><span class="n">keys</span><span class="p">())[</span><span class="n">start</span><span class="p">:</span><span class="n">end</span><span class="p">]</span>
  171. <span class="p">])</span></div>
  172. <span class="k">def</span> <span class="nf">_replace_activations_recursive</span><span class="p">(</span><span class="n">module</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">new_activation</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">activations_to_replace</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">type</span><span class="p">]):</span>
  173. <span class="sd">&quot;&quot;&quot;</span>
  174. <span class="sd"> A helper called in replace_activations(...)</span>
  175. <span class="sd"> &quot;&quot;&quot;</span>
  176. <span class="k">for</span> <span class="n">n</span><span class="p">,</span> <span class="n">m</span> <span class="ow">in</span> <span class="n">module</span><span class="o">.</span><span class="n">named_children</span><span class="p">():</span>
  177. <span class="k">if</span> <span class="nb">type</span><span class="p">(</span><span class="n">m</span><span class="p">)</span> <span class="ow">in</span> <span class="n">activations_to_replace</span><span class="p">:</span>
  178. <span class="nb">setattr</span><span class="p">(</span><span class="n">module</span><span class="p">,</span> <span class="n">n</span><span class="p">,</span> <span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="n">new_activation</span><span class="p">))</span>
  179. <span class="k">else</span><span class="p">:</span>
  180. <span class="n">_replace_activations_recursive</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">new_activation</span><span class="p">,</span> <span class="n">activations_to_replace</span><span class="p">)</span>
  181. <div class="viewcode-block" id="replace_activations"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.module_utils.replace_activations">[docs]</a><span class="k">def</span> <span class="nf">replace_activations</span><span class="p">(</span><span class="n">module</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">new_activation</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">activations_to_replace</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">type</span><span class="p">]):</span>
  182. <span class="sd">&quot;&quot;&quot;</span>
  183. <span class="sd"> Recursively go through module and replaces each activation in activations_to_replace with a copy of new_activation</span>
  184. <span class="sd"> :param module: a module that will be changed inplace</span>
  185. <span class="sd"> :param new_activation: a sample of a new activation (will be copied)</span>
  186. <span class="sd"> :param activations_to_replace: types of activations to replace, each must be a subclass of nn.Module</span>
  187. <span class="sd"> &quot;&quot;&quot;</span>
  188. <span class="c1"># check arguments once before the recursion</span>
  189. <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">new_activation</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">),</span> <span class="s1">&#39;new_activation should be nn.Module&#39;</span>
  190. <span class="k">assert</span> <span class="nb">all</span><span class="p">([</span><span class="nb">isinstance</span><span class="p">(</span><span class="n">t</span><span class="p">,</span> <span class="nb">type</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">issubclass</span><span class="p">(</span><span class="n">t</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">)</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">activations_to_replace</span><span class="p">]),</span> \
  191. <span class="s1">&#39;activations_to_replace should be types that are subclasses of nn.Module&#39;</span>
  192. <span class="c1"># do the replacement</span>
  193. <span class="n">_replace_activations_recursive</span><span class="p">(</span><span class="n">module</span><span class="p">,</span> <span class="n">new_activation</span><span class="p">,</span> <span class="n">activations_to_replace</span><span class="p">)</span></div>
  194. <div class="viewcode-block" id="fuse_repvgg_blocks_residual_branches"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.module_utils.fuse_repvgg_blocks_residual_branches">[docs]</a><span class="k">def</span> <span class="nf">fuse_repvgg_blocks_residual_branches</span><span class="p">(</span><span class="n">model</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
  195. <span class="sd">&#39;&#39;&#39;</span>
  196. <span class="sd"> Call fuse_block_residual_branches for all repvgg blocks in the model</span>
  197. <span class="sd"> :param model: torch.nn.Module with repvgg blocks. Doesn&#39;t have to be entirely consists of repvgg.</span>
  198. <span class="sd"> :type model: torch.nn.Module</span>
  199. <span class="sd"> &#39;&#39;&#39;</span>
  200. <span class="k">assert</span> <span class="ow">not</span> <span class="n">model</span><span class="o">.</span><span class="n">training</span><span class="p">,</span> <span class="s2">&quot;To fuse RepVGG block residual branches, model must be on eval mode&quot;</span>
  201. <span class="k">for</span> <span class="n">module</span> <span class="ow">in</span> <span class="n">model</span><span class="o">.</span><span class="n">modules</span><span class="p">():</span>
  202. <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">module</span><span class="p">,</span> <span class="s1">&#39;fuse_block_residual_branches&#39;</span><span class="p">):</span>
  203. <span class="n">module</span><span class="o">.</span><span class="n">fuse_block_residual_branches</span><span class="p">()</span>
  204. <span class="n">model</span><span class="o">.</span><span class="n">build_residual_branches</span> <span class="o">=</span> <span class="kc">False</span></div>
  205. <div class="viewcode-block" id="ConvBNReLU"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.module_utils.ConvBNReLU">[docs]</a><span class="k">class</span> <span class="nc">ConvBNReLU</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
  206. <span class="sd">&quot;&quot;&quot;</span>
  207. <span class="sd"> Class for Convolution2d-Batchnorm2d-Relu layer. Default behaviour is Conv-BN-Relu. To exclude Batchnorm module use</span>
  208. <span class="sd"> `use_normalization=False`, to exclude Relu activation use `use_activation=False`.</span>
  209. <span class="sd"> For convolution arguments documentation see `nn.Conv2d`.</span>
  210. <span class="sd"> For batchnorm arguments documentation see `nn.BatchNorm2d`.</span>
  211. <span class="sd"> For relu arguments documentation see `nn.Relu`.</span>
  212. <span class="sd"> &quot;&quot;&quot;</span>
  213. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
  214. <span class="n">in_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
  215. <span class="n">out_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
  216. <span class="n">kernel_size</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]],</span>
  217. <span class="n">stride</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
  218. <span class="n">padding</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span>
  219. <span class="n">dilation</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
  220. <span class="n">groups</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
  221. <span class="n">bias</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
  222. <span class="n">padding_mode</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s1">&#39;zeros&#39;</span><span class="p">,</span>
  223. <span class="n">use_normalization</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
  224. <span class="n">eps</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-5</span><span class="p">,</span>
  225. <span class="n">momentum</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
  226. <span class="n">affine</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
  227. <span class="n">track_running_stats</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
  228. <span class="n">device</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
  229. <span class="n">dtype</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
  230. <span class="n">use_activation</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
  231. <span class="n">inplace</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">):</span>
  232. <span class="nb">super</span><span class="p">(</span><span class="n">ConvBNReLU</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
  233. <span class="bp">self</span><span class="o">.</span><span class="n">seq</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">()</span>
  234. <span class="bp">self</span><span class="o">.</span><span class="n">seq</span><span class="o">.</span><span class="n">add_module</span><span class="p">(</span><span class="s2">&quot;conv&quot;</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span>
  235. <span class="n">out_channels</span><span class="p">,</span>
  236. <span class="n">kernel_size</span><span class="o">=</span><span class="n">kernel_size</span><span class="p">,</span>
  237. <span class="n">stride</span><span class="o">=</span><span class="n">stride</span><span class="p">,</span>
  238. <span class="n">padding</span><span class="o">=</span><span class="n">padding</span><span class="p">,</span>
  239. <span class="n">dilation</span><span class="o">=</span><span class="n">dilation</span><span class="p">,</span>
  240. <span class="n">groups</span><span class="o">=</span><span class="n">groups</span><span class="p">,</span>
  241. <span class="n">bias</span><span class="o">=</span><span class="n">bias</span><span class="p">,</span>
  242. <span class="n">padding_mode</span><span class="o">=</span><span class="n">padding_mode</span><span class="p">))</span>
  243. <span class="k">if</span> <span class="n">use_normalization</span><span class="p">:</span>
  244. <span class="bp">self</span><span class="o">.</span><span class="n">seq</span><span class="o">.</span><span class="n">add_module</span><span class="p">(</span><span class="s2">&quot;bn&quot;</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm2d</span><span class="p">(</span><span class="n">out_channels</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="n">eps</span><span class="p">,</span> <span class="n">momentum</span><span class="o">=</span><span class="n">momentum</span><span class="p">,</span> <span class="n">affine</span><span class="o">=</span><span class="n">affine</span><span class="p">,</span>
  245. <span class="n">track_running_stats</span><span class="o">=</span><span class="n">track_running_stats</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">,</span>
  246. <span class="n">dtype</span><span class="o">=</span><span class="n">dtype</span><span class="p">))</span>
  247. <span class="k">if</span> <span class="n">use_activation</span><span class="p">:</span>
  248. <span class="bp">self</span><span class="o">.</span><span class="n">seq</span><span class="o">.</span><span class="n">add_module</span><span class="p">(</span><span class="s2">&quot;relu&quot;</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(</span><span class="n">inplace</span><span class="o">=</span><span class="n">inplace</span><span class="p">))</span>
  249. <div class="viewcode-block" id="ConvBNReLU.forward"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.module_utils.ConvBNReLU.forward">[docs]</a> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
  250. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">seq</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></div></div>
  251. <div class="viewcode-block" id="NormalizationAdapter"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.module_utils.NormalizationAdapter">[docs]</a><span class="k">class</span> <span class="nc">NormalizationAdapter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
  252. <span class="sd">&quot;&quot;&quot;</span>
  253. <span class="sd"> Denormalizes input by mean_original, std_original, then normalizes by mean_required, std_required.</span>
  254. <span class="sd"> Used in KD training where teacher expects data normalized by mean_required, std_required.</span>
  255. <span class="sd"> mean_original, std_original, mean_required, std_required are all list-like objects of length that&#39;s equal to the</span>
  256. <span class="sd"> number of input channels.</span>
  257. <span class="sd"> &quot;&quot;&quot;</span>
  258. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mean_original</span><span class="p">,</span> <span class="n">std_original</span><span class="p">,</span> <span class="n">mean_required</span><span class="p">,</span> <span class="n">std_required</span><span class="p">):</span>
  259. <span class="nb">super</span><span class="p">(</span><span class="n">NormalizationAdapter</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
  260. <span class="n">mean_original</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">mean_original</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
  261. <span class="n">std_original</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">std_original</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
  262. <span class="n">mean_required</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">mean_required</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
  263. <span class="n">std_required</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">std_required</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
  264. <span class="bp">self</span><span class="o">.</span><span class="n">additive</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">((</span><span class="n">mean_original</span> <span class="o">-</span> <span class="n">mean_required</span><span class="p">)</span> <span class="o">/</span> <span class="n">std_original</span><span class="p">)</span>
  265. <span class="bp">self</span><span class="o">.</span><span class="n">multiplier</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">std_original</span> <span class="o">/</span> <span class="n">std_required</span><span class="p">)</span>
  266. <div class="viewcode-block" id="NormalizationAdapter.forward"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.module_utils.NormalizationAdapter.forward">[docs]</a> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
  267. <span class="n">x</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">additive</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">multiplier</span>
  268. <span class="k">return</span> <span class="n">x</span></div></div>
  269. </pre></div>
  270. </div>
  271. </div>
  272. <footer>
  273. <hr/>
  274. <div role="contentinfo">
  275. <p>&#169; Copyright 2021, SuperGradients team.</p>
  276. </div>
  277. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  278. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  279. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  280. </footer>
  281. </div>
  282. </div>
  283. </section>
  284. </div>
  285. <script>
  286. jQuery(function () {
  287. SphinxRtdTheme.Navigation.enable(true);
  288. });
  289. </script>
  290. </body>
  291. </html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.training.utils.optimizer_utils &mdash; SuperGradients 1.0 documentation</title>
  7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
  10. <!--[if lt IE 9]>
  11. <script src="../../../../_static/js/html5shiv.min.js"></script>
  12. <![endif]-->
  13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
  14. <script src="../../../../_static/jquery.js"></script>
  15. <script src="../../../../_static/underscore.js"></script>
  16. <script src="../../../../_static/doctools.js"></script>
  17. <script src="../../../../_static/js/theme.js"></script>
  18. <link rel="index" title="Index" href="../../../../genindex.html" />
  19. <link rel="search" title="Search" href="../../../../search.html" />
  20. </head>
  21. <body class="wy-body-for-nav">
  22. <div class="wy-grid-for-nav">
  23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  24. <div class="wy-side-scroll">
  25. <div class="wy-side-nav-search" >
  26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
  27. </a>
  28. <div role="search">
  29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
  30. <input type="text" name="q" placeholder="Search docs" />
  31. <input type="hidden" name="check_keywords" value="yes" />
  32. <input type="hidden" name="area" value="default" />
  33. </form>
  34. </div>
  35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
  37. <ul>
  38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
  39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
  40. </ul>
  41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
  42. <ul>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
  45. </ul>
  46. <p class="caption"><span class="caption-text">User Guide</span></p>
  47. <ul>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
  57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
  58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
  59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
  60. </ul>
  61. </div>
  62. </div>
  63. </nav>
  64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  66. <a href="../../../../index.html">SuperGradients</a>
  67. </nav>
  68. <div class="wy-nav-content">
  69. <div class="rst-content">
  70. <div role="navigation" aria-label="Page navigation">
  71. <ul class="wy-breadcrumbs">
  72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
  74. <li>super_gradients.training.utils.optimizer_utils</li>
  75. <li class="wy-breadcrumbs-aside">
  76. </li>
  77. </ul>
  78. <hr/>
  79. </div>
  80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  81. <div itemprop="articleBody">
  82. <h1>Source code for super_gradients.training.utils.optimizer_utils</h1><div class="highlight"><pre>
  83. <span></span><span class="kn">import</span> <span class="nn">torch.optim</span> <span class="k">as</span> <span class="nn">optim</span>
  84. <span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
  85. <span class="kn">from</span> <span class="nn">torch.nn.modules.batchnorm</span> <span class="kn">import</span> <span class="n">_BatchNorm</span>
  86. <span class="kn">from</span> <span class="nn">torch.nn.modules.conv</span> <span class="kn">import</span> <span class="n">_ConvNd</span>
  87. <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
  88. <span class="kn">from</span> <span class="nn">super_gradients.common.factories.optimizers_type_factory</span> <span class="kn">import</span> <span class="n">OptimizersTypeFactory</span>
  89. <span class="kn">from</span> <span class="nn">super_gradients.training.params</span> <span class="kn">import</span> <span class="n">DEFAULT_OPTIMIZER_PARAMS_SGD</span><span class="p">,</span> <span class="n">DEFAULT_OPTIMIZER_PARAMS_ADAM</span><span class="p">,</span> \
  90. <span class="n">DEFAULT_OPTIMIZER_PARAMS_RMSPROP</span><span class="p">,</span> <span class="n">DEFAULT_OPTIMIZER_PARAMS_RMSPROPTF</span>
  91. <span class="kn">from</span> <span class="nn">super_gradients.training.utils</span> <span class="kn">import</span> <span class="n">get_param</span>
  92. <span class="kn">from</span> <span class="nn">super_gradients.training.utils.optimizers.rmsprop_tf</span> <span class="kn">import</span> <span class="n">RMSpropTF</span>
  93. <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
  94. <span class="n">OPTIMIZERS_DEFAULT_PARAMS</span> <span class="o">=</span> <span class="p">{</span><span class="n">optim</span><span class="o">.</span><span class="n">SGD</span><span class="p">:</span> <span class="n">DEFAULT_OPTIMIZER_PARAMS_SGD</span><span class="p">,</span>
  95. <span class="n">optim</span><span class="o">.</span><span class="n">Adam</span><span class="p">:</span> <span class="n">DEFAULT_OPTIMIZER_PARAMS_ADAM</span><span class="p">,</span>
  96. <span class="n">optim</span><span class="o">.</span><span class="n">RMSprop</span><span class="p">:</span> <span class="n">DEFAULT_OPTIMIZER_PARAMS_RMSPROP</span><span class="p">,</span>
  97. <span class="n">RMSpropTF</span><span class="p">:</span> <span class="n">DEFAULT_OPTIMIZER_PARAMS_RMSPROPTF</span><span class="p">}</span>
  98. <div class="viewcode-block" id="separate_zero_wd_params_groups_for_optimizer"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.optimizer_utils.separate_zero_wd_params_groups_for_optimizer">[docs]</a><span class="k">def</span> <span class="nf">separate_zero_wd_params_groups_for_optimizer</span><span class="p">(</span><span class="n">module</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">net_named_params</span><span class="p">,</span> <span class="n">weight_decay</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span>
  99. <span class="sd">&quot;&quot;&quot;</span>
  100. <span class="sd"> separate param groups for batchnorm and biases and others with weight decay. return list of param groups in format</span>
  101. <span class="sd"> required by torch Optimizer classes.</span>
  102. <span class="sd"> bias + BN with weight decay=0 and the rest with the given weight decay</span>
  103. <span class="sd"> :param module: train net module.</span>
  104. <span class="sd"> :param net_named_params: list of params groups, output of SgModule.initialize_param_groups</span>
  105. <span class="sd"> :param weight_decay: value to set for the non BN and bias parameters</span>
  106. <span class="sd"> &quot;&quot;&quot;</span>
  107. <span class="c1"># FIXME - replace usage of ids addresses to find batchnorm and biases params.</span>
  108. <span class="c1"># This solution iterate 2 times over module parameters, find a way to iterate only one time.</span>
  109. <span class="n">no_decay_ids</span> <span class="o">=</span> <span class="n">_get_no_decay_param_ids</span><span class="p">(</span><span class="n">module</span><span class="p">)</span>
  110. <span class="c1"># split param groups for optimizer</span>
  111. <span class="n">optimizer_param_groups</span> <span class="o">=</span> <span class="p">[]</span>
  112. <span class="k">for</span> <span class="n">param_group</span> <span class="ow">in</span> <span class="n">net_named_params</span><span class="p">:</span>
  113. <span class="n">no_decay_params</span> <span class="o">=</span> <span class="p">[]</span>
  114. <span class="n">decay_params</span> <span class="o">=</span> <span class="p">[]</span>
  115. <span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">param_group</span><span class="p">[</span><span class="s2">&quot;named_params&quot;</span><span class="p">]:</span>
  116. <span class="k">if</span> <span class="nb">id</span><span class="p">(</span><span class="n">param</span><span class="p">)</span> <span class="ow">in</span> <span class="n">no_decay_ids</span><span class="p">:</span>
  117. <span class="n">no_decay_params</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">param</span><span class="p">)</span>
  118. <span class="k">else</span><span class="p">:</span>
  119. <span class="n">decay_params</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">param</span><span class="p">)</span>
  120. <span class="c1"># append two param groups from the original param group, with and without weight decay.</span>
  121. <span class="n">extra_optim_params</span> <span class="o">=</span> <span class="p">{</span><span class="n">key</span><span class="p">:</span> <span class="n">param_group</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">param_group</span>
  122. <span class="k">if</span> <span class="n">key</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span><span class="s2">&quot;named_params&quot;</span><span class="p">,</span> <span class="s2">&quot;weight_decay&quot;</span><span class="p">]}</span>
  123. <span class="n">optimizer_param_groups</span><span class="o">.</span><span class="n">append</span><span class="p">({</span><span class="s2">&quot;params&quot;</span><span class="p">:</span> <span class="n">no_decay_params</span><span class="p">,</span> <span class="s2">&quot;weight_decay&quot;</span><span class="p">:</span> <span class="mf">0.0</span><span class="p">,</span> <span class="o">**</span><span class="n">extra_optim_params</span><span class="p">})</span>
  124. <span class="n">optimizer_param_groups</span><span class="o">.</span><span class="n">append</span><span class="p">({</span><span class="s2">&quot;params&quot;</span><span class="p">:</span> <span class="n">decay_params</span><span class="p">,</span> <span class="s2">&quot;weight_decay&quot;</span><span class="p">:</span> <span class="n">weight_decay</span><span class="p">,</span> <span class="o">**</span><span class="n">extra_optim_params</span><span class="p">})</span>
  125. <span class="k">return</span> <span class="n">optimizer_param_groups</span></div>
  126. <span class="k">def</span> <span class="nf">_get_no_decay_param_ids</span><span class="p">(</span><span class="n">module</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
  127. <span class="c1"># FIXME - replace usage of ids addresses to find batchnorm and biases params.</span>
  128. <span class="c1"># Use other common way to identify torch parameters other than id or layer names</span>
  129. <span class="sd">&quot;&quot;&quot;</span>
  130. <span class="sd"> Iterate over module.modules() and returns params id addresses of batch-norm and biases params.</span>
  131. <span class="sd"> NOTE - ALL MODULES WITH ATTRIBUTES NAMED BIAS AND ARE INSTANCE OF nn.Parameter WILL BE CONSIDERED A BIAS PARAM FOR</span>
  132. <span class="sd"> ZERO WEIGHT DECAY.</span>
  133. <span class="sd"> &quot;&quot;&quot;</span>
  134. <span class="n">batchnorm_types</span> <span class="o">=</span> <span class="p">(</span><span class="n">_BatchNorm</span><span class="p">,)</span>
  135. <span class="n">torch_weight_with_bias_types</span> <span class="o">=</span> <span class="p">(</span><span class="n">_ConvNd</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">)</span>
  136. <span class="n">no_decay_ids</span> <span class="o">=</span> <span class="p">[]</span>
  137. <span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">m</span> <span class="ow">in</span> <span class="n">module</span><span class="o">.</span><span class="n">named_modules</span><span class="p">():</span>
  138. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">batchnorm_types</span><span class="p">):</span>
  139. <span class="n">no_decay_ids</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="nb">id</span><span class="p">(</span><span class="n">m</span><span class="o">.</span><span class="n">weight</span><span class="p">))</span>
  140. <span class="n">no_decay_ids</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="nb">id</span><span class="p">(</span><span class="n">m</span><span class="o">.</span><span class="n">bias</span><span class="p">))</span>
  141. <span class="k">elif</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="s2">&quot;bias&quot;</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">m</span><span class="o">.</span><span class="n">bias</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">):</span>
  142. <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">torch_weight_with_bias_types</span><span class="p">):</span>
  143. <span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Module class: </span><span class="si">{</span><span class="n">m</span><span class="o">.</span><span class="vm">__class__</span><span class="si">}</span><span class="s2">, have a `bias` parameter attribute but is not instance of&quot;</span>
  144. <span class="sa">f</span><span class="s2">&quot; torch primitive modules, this bias parameter will be part of param group with zero&quot;</span>
  145. <span class="sa">f</span><span class="s2">&quot; weight decay.&quot;</span><span class="p">)</span>
  146. <span class="n">no_decay_ids</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="nb">id</span><span class="p">(</span><span class="n">m</span><span class="o">.</span><span class="n">bias</span><span class="p">))</span>
  147. <span class="k">return</span> <span class="n">no_decay_ids</span>
  148. <div class="viewcode-block" id="build_optimizer"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.optimizer_utils.build_optimizer">[docs]</a><span class="k">def</span> <span class="nf">build_optimizer</span><span class="p">(</span><span class="n">net</span><span class="p">,</span> <span class="n">lr</span><span class="p">,</span> <span class="n">training_params</span><span class="p">):</span>
  149. <span class="sd">&quot;&quot;&quot;</span>
  150. <span class="sd"> Wrapper function for initializing the optimizer</span>
  151. <span class="sd"> :param net: the nn_module to build the optimizer for</span>
  152. <span class="sd"> :param lr: initial learning rate</span>
  153. <span class="sd"> :param training_params: training_parameters</span>
  154. <span class="sd"> &quot;&quot;&quot;</span>
  155. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">training_params</span><span class="o">.</span><span class="n">optimizer</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span>
  156. <span class="n">optimizer_cls</span> <span class="o">=</span> <span class="n">OptimizersTypeFactory</span><span class="p">()</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">training_params</span><span class="o">.</span><span class="n">optimizer</span><span class="p">)</span>
  157. <span class="k">else</span><span class="p">:</span>
  158. <span class="n">optimizer_cls</span> <span class="o">=</span> <span class="n">training_params</span><span class="o">.</span><span class="n">optimizer</span>
  159. <span class="n">default_optimizer_params</span> <span class="o">=</span> <span class="n">OPTIMIZERS_DEFAULT_PARAMS</span><span class="p">[</span><span class="n">optimizer_cls</span><span class="p">]</span> <span class="k">if</span> <span class="n">optimizer_cls</span> <span class="ow">in</span> <span class="n">OPTIMIZERS_DEFAULT_PARAMS</span> <span class="k">else</span> <span class="p">{}</span>
  160. <span class="n">training_params</span><span class="o">.</span><span class="n">optimizer_params</span> <span class="o">=</span> <span class="n">get_param</span><span class="p">(</span><span class="n">training_params</span><span class="p">,</span> <span class="s1">&#39;optimizer_params&#39;</span><span class="p">,</span> <span class="n">default_optimizer_params</span><span class="p">)</span>
  161. <span class="c1"># OPTIMIZER PARAM GROUPS ARE SET USING DEFAULT OR MODEL SPECIFIC INIT</span>
  162. <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">net</span><span class="o">.</span><span class="n">module</span><span class="p">,</span> <span class="s1">&#39;initialize_param_groups&#39;</span><span class="p">):</span>
  163. <span class="c1"># INITIALIZE_PARAM_GROUPS MUST RETURN A LIST OF DICTS WITH &#39;named_params&#39; AND OPTIMIZER&#39;s ATTRIBUTES PER GROUP</span>
  164. <span class="n">net_named_params</span> <span class="o">=</span> <span class="n">net</span><span class="o">.</span><span class="n">module</span><span class="o">.</span><span class="n">initialize_param_groups</span><span class="p">(</span><span class="n">lr</span><span class="p">,</span> <span class="n">training_params</span><span class="p">)</span>
  165. <span class="k">else</span><span class="p">:</span>
  166. <span class="n">net_named_params</span> <span class="o">=</span> <span class="p">[{</span><span class="s1">&#39;named_params&#39;</span><span class="p">:</span> <span class="n">net</span><span class="o">.</span><span class="n">named_parameters</span><span class="p">()}]</span>
  167. <span class="k">if</span> <span class="n">training_params</span><span class="o">.</span><span class="n">zero_weight_decay_on_bias_and_bn</span><span class="p">:</span>
  168. <span class="n">optimizer_training_params</span> <span class="o">=</span> <span class="n">separate_zero_wd_params_groups_for_optimizer</span><span class="p">(</span>
  169. <span class="n">net</span><span class="o">.</span><span class="n">module</span><span class="p">,</span> <span class="n">net_named_params</span><span class="p">,</span> <span class="n">training_params</span><span class="o">.</span><span class="n">optimizer_params</span><span class="p">[</span><span class="s1">&#39;weight_decay&#39;</span><span class="p">]</span>
  170. <span class="p">)</span>
  171. <span class="k">else</span><span class="p">:</span>
  172. <span class="c1"># Overwrite groups to include params instead of named params</span>
  173. <span class="k">for</span> <span class="n">ind_group</span><span class="p">,</span> <span class="n">param_group</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">net_named_params</span><span class="p">):</span>
  174. <span class="n">param_group</span><span class="p">[</span><span class="s1">&#39;params&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span><span class="n">param</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="k">for</span> <span class="n">param</span> <span class="ow">in</span> <span class="nb">list</span><span class="p">(</span><span class="n">param_group</span><span class="p">[</span><span class="s1">&#39;named_params&#39;</span><span class="p">])]</span>
  175. <span class="k">del</span> <span class="n">param_group</span><span class="p">[</span><span class="s1">&#39;named_params&#39;</span><span class="p">]</span>
  176. <span class="n">net_named_params</span><span class="p">[</span><span class="n">ind_group</span><span class="p">]</span> <span class="o">=</span> <span class="n">param_group</span>
  177. <span class="n">optimizer_training_params</span> <span class="o">=</span> <span class="n">net_named_params</span>
  178. <span class="c1"># CREATE AN OPTIMIZER OBJECT AND INITIALIZE IT</span>
  179. <span class="n">optimizer</span> <span class="o">=</span> <span class="n">optimizer_cls</span><span class="p">(</span><span class="n">optimizer_training_params</span><span class="p">,</span> <span class="n">lr</span><span class="o">=</span><span class="n">lr</span><span class="p">,</span> <span class="o">**</span><span class="n">training_params</span><span class="o">.</span><span class="n">optimizer_params</span><span class="p">)</span>
  180. <span class="k">return</span> <span class="n">optimizer</span></div>
  181. </pre></div>
  182. </div>
  183. </div>
  184. <footer>
  185. <hr/>
  186. <div role="contentinfo">
  187. <p>&#169; Copyright 2021, SuperGradients team.</p>
  188. </div>
  189. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  190. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  191. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  192. </footer>
  193. </div>
  194. </div>
  195. </section>
  196. </div>
  197. <script>
  198. jQuery(function () {
  199. SphinxRtdTheme.Navigation.enable(true);
  200. });
  201. </script>
  202. </body>
  203. </html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.training.utils.optimizers.rmsprop_tf &mdash; SuperGradients 1.0 documentation</title>
  7. <link rel="stylesheet" href="../../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../../_static/graphviz.css" type="text/css" />
  10. <!--[if lt IE 9]>
  11. <script src="../../../../../_static/js/html5shiv.min.js"></script>
  12. <![endif]-->
  13. <script data-url_root="../../../../../" id="documentation_options" src="../../../../../_static/documentation_options.js"></script>
  14. <script src="../../../../../_static/jquery.js"></script>
  15. <script src="../../../../../_static/underscore.js"></script>
  16. <script src="../../../../../_static/doctools.js"></script>
  17. <script src="../../../../../_static/js/theme.js"></script>
  18. <link rel="index" title="Index" href="../../../../../genindex.html" />
  19. <link rel="search" title="Search" href="../../../../../search.html" />
  20. </head>
  21. <body class="wy-body-for-nav">
  22. <div class="wy-grid-for-nav">
  23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  24. <div class="wy-side-scroll">
  25. <div class="wy-side-nav-search" >
  26. <a href="../../../../../index.html" class="icon icon-home"> SuperGradients
  27. </a>
  28. <div role="search">
  29. <form id="rtd-search-form" class="wy-form" action="../../../../../search.html" method="get">
  30. <input type="text" name="q" placeholder="Search docs" />
  31. <input type="hidden" name="check_keywords" value="yes" />
  32. <input type="hidden" name="area" value="default" />
  33. </form>
  34. </div>
  35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
  37. <ul>
  38. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
  39. <li class="toctree-l1"><a class="reference internal" href="../../../../../welcome.html#supergradients">SuperGradients</a></li>
  40. </ul>
  41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
  42. <ul>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.common.html">Common package</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../../super_gradients.training.html">Training package</a></li>
  45. </ul>
  46. <p class="caption"><span class="caption-text">User Guide</span></p>
  47. <ul>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html">What is SuperGradients?</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#installation">Installation</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#training-parameters">Training Parameters</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#network-architectures">Network Architectures</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
  57. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
  58. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
  59. <li class="toctree-l1"><a class="reference internal" href="../../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
  60. </ul>
  61. </div>
  62. </div>
  63. </nav>
  64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  66. <a href="../../../../../index.html">SuperGradients</a>
  67. </nav>
  68. <div class="wy-nav-content">
  69. <div class="rst-content">
  70. <div role="navigation" aria-label="Page navigation">
  71. <ul class="wy-breadcrumbs">
  72. <li><a href="../../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  73. <li><a href="../../../../index.html">Module code</a> &raquo;</li>
  74. <li>super_gradients.training.utils.optimizers.rmsprop_tf</li>
  75. <li class="wy-breadcrumbs-aside">
  76. </li>
  77. </ul>
  78. <hr/>
  79. </div>
  80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  81. <div itemprop="articleBody">
  82. <h1>Source code for super_gradients.training.utils.optimizers.rmsprop_tf</h1><div class="highlight"><pre>
  83. <span></span><span class="kn">import</span> <span class="nn">torch</span>
  84. <span class="kn">from</span> <span class="nn">torch.optim</span> <span class="kn">import</span> <span class="n">Optimizer</span>
  85. <span class="sd">&quot;&quot;&quot;</span>
  86. <span class="sd">This implementation is taken from timm&#39;s github:</span>
  87. <span class="sd">https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/rmsprop_tf.py</span>
  88. <span class="sd">&quot;&quot;&quot;</span>
  89. <span class="sd">&quot;&quot;&quot; RMSProp modified to behave like Tensorflow impl</span>
  90. <span class="sd">Originally cut &amp; paste from PyTorch RMSProp</span>
  91. <span class="sd">https://github.com/pytorch/pytorch/blob/063946d2b3f3f1e953a2a3b54e0b34f1393de295/torch/optim/rmsprop.py</span>
  92. <span class="sd">Licensed under BSD-Clause 3 (ish), https://github.com/pytorch/pytorch/blob/master/LICENSE</span>
  93. <span class="sd">Modifications Copyright 2020 Ross Wightman</span>
  94. <span class="sd">&quot;&quot;&quot;</span>
  95. <div class="viewcode-block" id="RMSpropTF"><a class="viewcode-back" href="../../../../../super_gradients.training.utils.optimizers.html#super_gradients.training.utils.optimizers.rmsprop_tf.RMSpropTF">[docs]</a><span class="k">class</span> <span class="nc">RMSpropTF</span><span class="p">(</span><span class="n">Optimizer</span><span class="p">):</span>
  96. <span class="sd">&quot;&quot;&quot;Implements RMSprop algorithm (TensorFlow style epsilon)</span>
  97. <span class="sd"> NOTE: This is a direct cut-and-paste of PyTorch RMSprop with eps applied before sqrt</span>
  98. <span class="sd"> and a few other modifications to closer match Tensorflow for matching hyper-params.</span>
  99. <span class="sd"> Noteworthy changes include:</span>
  100. <span class="sd"> 1. Epsilon applied inside square-root</span>
  101. <span class="sd"> 2. square_avg initialized to ones</span>
  102. <span class="sd"> 3. LR scaling of update accumulated in momentum buffer</span>
  103. <span class="sd"> Proposed by G. Hinton in his</span>
  104. <span class="sd"> `course &lt;http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf&gt;`_.</span>
  105. <span class="sd"> The centered version first appears in `Generating Sequences</span>
  106. <span class="sd"> With Recurrent Neural Networks &lt;https://arxiv.org/pdf/1308.0850v5.pdf&gt;`_.&quot;&quot;&quot;</span>
  107. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="n">lr</span><span class="o">=</span><span class="mf">1e-2</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-10</span><span class="p">,</span> <span class="n">weight_decay</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">momentum</span><span class="o">=</span><span class="mf">0.</span><span class="p">,</span> <span class="n">centered</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
  108. <span class="n">decoupled_decay</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">lr_in_momentum</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
  109. <span class="sd">&quot;&quot;&quot;RMSprop optimizer that follows the tf&#39;s RMSprop characteristics</span>
  110. <span class="sd"> :param params (iterable): iterable of parameters to optimize or dicts defining parameter groups</span>
  111. <span class="sd"> :param lr (float, optional): learning rate</span>
  112. <span class="sd"> :param momentum (float, optional): momentum factor</span>
  113. <span class="sd"> :param alpha (float, optional): smoothing (decay) constant</span>
  114. <span class="sd"> :param eps (float, optional): term added to the denominator to improve numerical stability</span>
  115. <span class="sd"> :param centered (bool, optional) : if ``True``, compute the centered RMSProp, the gradient is normalized by an</span>
  116. <span class="sd"> estimation of its variance</span>
  117. <span class="sd"> :param weight_decay (float, optional): weight decay (L2 penalty)</span>
  118. <span class="sd"> :param decoupled_decay (bool, optional): decoupled weight decay as per https://arxiv.org/abs/1711.05101</span>
  119. <span class="sd"> :param lr_in_momentum (bool, optional): learning rate scaling is included in the momentum buffer update as per</span>
  120. <span class="sd"> defaults in Tensorflow</span>
  121. <span class="sd"> &quot;&quot;&quot;</span>
  122. <span class="k">if</span> <span class="ow">not</span> <span class="mf">0.0</span> <span class="o">&lt;=</span> <span class="n">lr</span><span class="p">:</span>
  123. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Invalid learning rate: </span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">lr</span><span class="p">))</span>
  124. <span class="k">if</span> <span class="ow">not</span> <span class="mf">0.0</span> <span class="o">&lt;=</span> <span class="n">eps</span><span class="p">:</span>
  125. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Invalid epsilon value: </span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">eps</span><span class="p">))</span>
  126. <span class="k">if</span> <span class="ow">not</span> <span class="mf">0.0</span> <span class="o">&lt;=</span> <span class="n">momentum</span><span class="p">:</span>
  127. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Invalid momentum value: </span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">momentum</span><span class="p">))</span>
  128. <span class="k">if</span> <span class="ow">not</span> <span class="mf">0.0</span> <span class="o">&lt;=</span> <span class="n">weight_decay</span><span class="p">:</span>
  129. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Invalid weight_decay value: </span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">weight_decay</span><span class="p">))</span>
  130. <span class="k">if</span> <span class="ow">not</span> <span class="mf">0.0</span> <span class="o">&lt;=</span> <span class="n">alpha</span><span class="p">:</span>
  131. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Invalid alpha value: </span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">alpha</span><span class="p">))</span>
  132. <span class="n">defaults</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(</span><span class="n">lr</span><span class="o">=</span><span class="n">lr</span><span class="p">,</span> <span class="n">momentum</span><span class="o">=</span><span class="n">momentum</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="n">alpha</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="n">eps</span><span class="p">,</span> <span class="n">centered</span><span class="o">=</span><span class="n">centered</span><span class="p">,</span> <span class="n">weight_decay</span><span class="o">=</span><span class="n">weight_decay</span><span class="p">,</span>
  133. <span class="n">decoupled_decay</span><span class="o">=</span><span class="n">decoupled_decay</span><span class="p">,</span> <span class="n">lr_in_momentum</span><span class="o">=</span><span class="n">lr_in_momentum</span><span class="p">)</span>
  134. <span class="nb">super</span><span class="p">(</span><span class="n">RMSpropTF</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">defaults</span><span class="p">)</span>
  135. <span class="k">def</span> <span class="nf">__setstate__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">state</span><span class="p">):</span>
  136. <span class="nb">super</span><span class="p">(</span><span class="n">RMSpropTF</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">__setstate__</span><span class="p">(</span><span class="n">state</span><span class="p">)</span>
  137. <span class="k">for</span> <span class="n">group</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">param_groups</span><span class="p">:</span>
  138. <span class="n">group</span><span class="o">.</span><span class="n">setdefault</span><span class="p">(</span><span class="s1">&#39;momentum&#39;</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span>
  139. <span class="n">group</span><span class="o">.</span><span class="n">setdefault</span><span class="p">(</span><span class="s1">&#39;centered&#39;</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span>
  140. <div class="viewcode-block" id="RMSpropTF.step"><a class="viewcode-back" href="../../../../../super_gradients.training.utils.optimizers.html#super_gradients.training.utils.optimizers.rmsprop_tf.RMSpropTF.step">[docs]</a> <span class="k">def</span> <span class="nf">step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">closure</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="c1"># noqa: C901</span>
  141. <span class="sd">&quot;&quot;&quot;Performs a single optimization step.</span>
  142. <span class="sd"> Arguments:</span>
  143. <span class="sd"> closure (callable, optional): A closure that reevaluates the model</span>
  144. <span class="sd"> and returns the loss.</span>
  145. <span class="sd"> &quot;&quot;&quot;</span>
  146. <span class="n">loss</span> <span class="o">=</span> <span class="kc">None</span>
  147. <span class="k">if</span> <span class="n">closure</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  148. <span class="n">loss</span> <span class="o">=</span> <span class="n">closure</span><span class="p">()</span>
  149. <span class="k">for</span> <span class="n">group</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">param_groups</span><span class="p">:</span>
  150. <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">group</span><span class="p">[</span><span class="s1">&#39;params&#39;</span><span class="p">]:</span>
  151. <span class="k">if</span> <span class="n">p</span><span class="o">.</span><span class="n">grad</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
  152. <span class="k">continue</span>
  153. <span class="n">grad</span> <span class="o">=</span> <span class="n">p</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">data</span>
  154. <span class="k">if</span> <span class="n">grad</span><span class="o">.</span><span class="n">is_sparse</span><span class="p">:</span>
  155. <span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="s1">&#39;RMSprop does not support sparse gradients&#39;</span><span class="p">)</span>
  156. <span class="n">state</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">state</span><span class="p">[</span><span class="n">p</span><span class="p">]</span>
  157. <span class="c1"># State initialization</span>
  158. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">state</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
  159. <span class="n">state</span><span class="p">[</span><span class="s1">&#39;step&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span>
  160. <span class="n">state</span><span class="p">[</span><span class="s1">&#39;square_avg&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="p">)</span> <span class="c1"># PyTorch inits to zero</span>
  161. <span class="k">if</span> <span class="n">group</span><span class="p">[</span><span class="s1">&#39;momentum&#39;</span><span class="p">]</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
  162. <span class="n">state</span><span class="p">[</span><span class="s1">&#39;momentum_buffer&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="p">)</span>
  163. <span class="k">if</span> <span class="n">group</span><span class="p">[</span><span class="s1">&#39;centered&#39;</span><span class="p">]:</span>
  164. <span class="n">state</span><span class="p">[</span><span class="s1">&#39;grad_avg&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="p">)</span>
  165. <span class="n">square_avg</span> <span class="o">=</span> <span class="n">state</span><span class="p">[</span><span class="s1">&#39;square_avg&#39;</span><span class="p">]</span>
  166. <span class="n">one_minus_alpha</span> <span class="o">=</span> <span class="mf">1.</span> <span class="o">-</span> <span class="n">group</span><span class="p">[</span><span class="s1">&#39;alpha&#39;</span><span class="p">]</span>
  167. <span class="n">state</span><span class="p">[</span><span class="s1">&#39;step&#39;</span><span class="p">]</span> <span class="o">+=</span> <span class="mi">1</span>
  168. <span class="k">if</span> <span class="n">group</span><span class="p">[</span><span class="s1">&#39;weight_decay&#39;</span><span class="p">]</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span>
  169. <span class="k">if</span> <span class="s1">&#39;decoupled_decay&#39;</span> <span class="ow">in</span> <span class="n">group</span> <span class="ow">and</span> <span class="n">group</span><span class="p">[</span><span class="s1">&#39;decoupled_decay&#39;</span><span class="p">]:</span>
  170. <span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">add_</span><span class="p">(</span><span class="o">-</span><span class="n">group</span><span class="p">[</span><span class="s1">&#39;weight_decay&#39;</span><span class="p">],</span> <span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="p">)</span>
  171. <span class="k">else</span><span class="p">:</span>
  172. <span class="n">grad</span> <span class="o">=</span> <span class="n">grad</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">group</span><span class="p">[</span><span class="s1">&#39;weight_decay&#39;</span><span class="p">],</span> <span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="p">)</span>
  173. <span class="c1"># Tensorflow order of ops for updating squared avg</span>
  174. <span class="n">square_avg</span><span class="o">.</span><span class="n">add_</span><span class="p">(</span><span class="n">one_minus_alpha</span><span class="p">,</span> <span class="n">grad</span><span class="o">.</span><span class="n">pow</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span> <span class="o">-</span> <span class="n">square_avg</span><span class="p">)</span>
  175. <span class="c1"># square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad) # PyTorch original</span>
  176. <span class="k">if</span> <span class="n">group</span><span class="p">[</span><span class="s1">&#39;centered&#39;</span><span class="p">]:</span>
  177. <span class="n">grad_avg</span> <span class="o">=</span> <span class="n">state</span><span class="p">[</span><span class="s1">&#39;grad_avg&#39;</span><span class="p">]</span>
  178. <span class="n">grad_avg</span><span class="o">.</span><span class="n">add_</span><span class="p">(</span><span class="n">one_minus_alpha</span><span class="p">,</span> <span class="n">grad</span> <span class="o">-</span> <span class="n">grad_avg</span><span class="p">)</span>
  179. <span class="c1"># grad_avg.mul_(alpha).add_(1 - alpha, grad) # PyTorch original</span>
  180. <span class="n">avg</span> <span class="o">=</span> <span class="n">square_avg</span><span class="o">.</span><span class="n">addcmul</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">grad_avg</span><span class="p">,</span> <span class="n">grad_avg</span><span class="p">)</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">group</span><span class="p">[</span><span class="s1">&#39;eps&#39;</span><span class="p">])</span><span class="o">.</span><span class="n">sqrt_</span><span class="p">()</span> <span class="c1"># eps moved in sqrt</span>
  181. <span class="k">else</span><span class="p">:</span>
  182. <span class="n">avg</span> <span class="o">=</span> <span class="n">square_avg</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">group</span><span class="p">[</span><span class="s1">&#39;eps&#39;</span><span class="p">])</span><span class="o">.</span><span class="n">sqrt_</span><span class="p">()</span> <span class="c1"># eps moved in sqrt</span>
  183. <span class="k">if</span> <span class="n">group</span><span class="p">[</span><span class="s1">&#39;momentum&#39;</span><span class="p">]</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
  184. <span class="n">buf</span> <span class="o">=</span> <span class="n">state</span><span class="p">[</span><span class="s1">&#39;momentum_buffer&#39;</span><span class="p">]</span>
  185. <span class="c1"># Tensorflow accumulates the LR scaling in the momentum buffer</span>
  186. <span class="k">if</span> <span class="s1">&#39;lr_in_momentum&#39;</span> <span class="ow">in</span> <span class="n">group</span> <span class="ow">and</span> <span class="n">group</span><span class="p">[</span><span class="s1">&#39;lr_in_momentum&#39;</span><span class="p">]:</span>
  187. <span class="n">buf</span><span class="o">.</span><span class="n">mul_</span><span class="p">(</span><span class="n">group</span><span class="p">[</span><span class="s1">&#39;momentum&#39;</span><span class="p">])</span><span class="o">.</span><span class="n">addcdiv_</span><span class="p">(</span><span class="n">group</span><span class="p">[</span><span class="s1">&#39;lr&#39;</span><span class="p">],</span> <span class="n">grad</span><span class="p">,</span> <span class="n">avg</span><span class="p">)</span>
  188. <span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">add_</span><span class="p">(</span><span class="o">-</span><span class="n">buf</span><span class="p">)</span>
  189. <span class="k">else</span><span class="p">:</span>
  190. <span class="c1"># PyTorch scales the param update by LR</span>
  191. <span class="n">buf</span><span class="o">.</span><span class="n">mul_</span><span class="p">(</span><span class="n">group</span><span class="p">[</span><span class="s1">&#39;momentum&#39;</span><span class="p">])</span><span class="o">.</span><span class="n">addcdiv_</span><span class="p">(</span><span class="n">grad</span><span class="p">,</span> <span class="n">avg</span><span class="p">)</span>
  192. <span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">add_</span><span class="p">(</span><span class="o">-</span><span class="n">group</span><span class="p">[</span><span class="s1">&#39;lr&#39;</span><span class="p">],</span> <span class="n">buf</span><span class="p">)</span>
  193. <span class="k">else</span><span class="p">:</span>
  194. <span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">addcdiv_</span><span class="p">(</span><span class="o">-</span><span class="n">group</span><span class="p">[</span><span class="s1">&#39;lr&#39;</span><span class="p">],</span> <span class="n">grad</span><span class="p">,</span> <span class="n">avg</span><span class="p">)</span>
  195. <span class="k">return</span> <span class="n">loss</span></div></div>
  196. </pre></div>
  197. </div>
  198. </div>
  199. <footer>
  200. <hr/>
  201. <div role="contentinfo">
  202. <p>&#169; Copyright 2021, SuperGradients team.</p>
  203. </div>
  204. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  205. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  206. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  207. </footer>
  208. </div>
  209. </div>
  210. </section>
  211. </div>
  212. <script>
  213. jQuery(function () {
  214. SphinxRtdTheme.Navigation.enable(true);
  215. });
  216. </script>
  217. </body>
  218. </html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.training.utils.regularization_utils &mdash; SuperGradients 1.0 documentation</title>
  7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
  10. <!--[if lt IE 9]>
  11. <script src="../../../../_static/js/html5shiv.min.js"></script>
  12. <![endif]-->
  13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
  14. <script src="../../../../_static/jquery.js"></script>
  15. <script src="../../../../_static/underscore.js"></script>
  16. <script src="../../../../_static/doctools.js"></script>
  17. <script src="../../../../_static/js/theme.js"></script>
  18. <link rel="index" title="Index" href="../../../../genindex.html" />
  19. <link rel="search" title="Search" href="../../../../search.html" />
  20. </head>
  21. <body class="wy-body-for-nav">
  22. <div class="wy-grid-for-nav">
  23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  24. <div class="wy-side-scroll">
  25. <div class="wy-side-nav-search" >
  26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
  27. </a>
  28. <div role="search">
  29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
  30. <input type="text" name="q" placeholder="Search docs" />
  31. <input type="hidden" name="check_keywords" value="yes" />
  32. <input type="hidden" name="area" value="default" />
  33. </form>
  34. </div>
  35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
  37. <ul>
  38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
  39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
  40. </ul>
  41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
  42. <ul>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
  45. </ul>
  46. <p class="caption"><span class="caption-text">User Guide</span></p>
  47. <ul>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
  57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
  58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
  59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
  60. </ul>
  61. </div>
  62. </div>
  63. </nav>
  64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  66. <a href="../../../../index.html">SuperGradients</a>
  67. </nav>
  68. <div class="wy-nav-content">
  69. <div class="rst-content">
  70. <div role="navigation" aria-label="Page navigation">
  71. <ul class="wy-breadcrumbs">
  72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
  74. <li>super_gradients.training.utils.regularization_utils</li>
  75. <li class="wy-breadcrumbs-aside">
  76. </li>
  77. </ul>
  78. <hr/>
  79. </div>
  80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  81. <div itemprop="articleBody">
  82. <h1>Source code for super_gradients.training.utils.regularization_utils</h1><div class="highlight"><pre>
  83. <span></span><span class="kn">import</span> <span class="nn">torch</span>
  84. <span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
  85. <div class="viewcode-block" id="DropPath"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.regularization_utils.DropPath">[docs]</a><span class="k">class</span> <span class="nc">DropPath</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
  86. <span class="sd">&quot;&quot;&quot;</span>
  87. <span class="sd"> Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).</span>
  88. <span class="sd"> Code taken from TIMM (https://github.com/rwightman/pytorch-image-models)</span>
  89. <span class="sd"> Apache License 2.0</span>
  90. <span class="sd"> &quot;&quot;&quot;</span>
  91. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">drop_prob</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  92. <span class="nb">super</span><span class="p">(</span><span class="n">DropPath</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
  93. <span class="bp">self</span><span class="o">.</span><span class="n">drop_prob</span> <span class="o">=</span> <span class="n">drop_prob</span>
  94. <div class="viewcode-block" id="DropPath.forward"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.regularization_utils.DropPath.forward">[docs]</a> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
  95. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">drop_prob</span> <span class="o">==</span> <span class="mf">0.</span> <span class="ow">or</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">training</span><span class="p">:</span>
  96. <span class="k">return</span> <span class="n">x</span>
  97. <span class="n">keep_prob</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">drop_prob</span>
  98. <span class="n">shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],)</span> <span class="o">+</span> <span class="p">(</span><span class="mi">1</span><span class="p">,)</span> <span class="o">*</span> <span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">ndim</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="c1"># work with diff dim tensors, not just 2D ConvNets</span>
  99. <span class="n">random_tensor</span> <span class="o">=</span> <span class="n">keep_prob</span> <span class="o">+</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">shape</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
  100. <span class="n">random_tensor</span><span class="o">.</span><span class="n">floor_</span><span class="p">()</span> <span class="c1"># binarize</span>
  101. <span class="n">output</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">div</span><span class="p">(</span><span class="n">keep_prob</span><span class="p">)</span> <span class="o">*</span> <span class="n">random_tensor</span>
  102. <span class="k">return</span> <span class="n">output</span></div></div>
  103. </pre></div>
  104. </div>
  105. </div>
  106. <footer>
  107. <hr/>
  108. <div role="contentinfo">
  109. <p>&#169; Copyright 2021, SuperGradients team.</p>
  110. </div>
  111. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  112. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  113. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  114. </footer>
  115. </div>
  116. </div>
  117. </section>
  118. </div>
  119. <script>
  120. jQuery(function () {
  121. SphinxRtdTheme.Navigation.enable(true);
  122. });
  123. </script>
  124. </body>
  125. </html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.training.utils.segmentation_utils &mdash; SuperGradients 1.0 documentation</title>
  7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
  10. <!--[if lt IE 9]>
  11. <script src="../../../../_static/js/html5shiv.min.js"></script>
  12. <![endif]-->
  13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
  14. <script src="../../../../_static/jquery.js"></script>
  15. <script src="../../../../_static/underscore.js"></script>
  16. <script src="../../../../_static/doctools.js"></script>
  17. <script src="../../../../_static/js/theme.js"></script>
  18. <link rel="index" title="Index" href="../../../../genindex.html" />
  19. <link rel="search" title="Search" href="../../../../search.html" />
  20. </head>
  21. <body class="wy-body-for-nav">
  22. <div class="wy-grid-for-nav">
  23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  24. <div class="wy-side-scroll">
  25. <div class="wy-side-nav-search" >
  26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
  27. </a>
  28. <div role="search">
  29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
  30. <input type="text" name="q" placeholder="Search docs" />
  31. <input type="hidden" name="check_keywords" value="yes" />
  32. <input type="hidden" name="area" value="default" />
  33. </form>
  34. </div>
  35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
  37. <ul>
  38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
  39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
  40. </ul>
  41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
  42. <ul>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
  45. </ul>
  46. <p class="caption"><span class="caption-text">User Guide</span></p>
  47. <ul>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
  57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
  58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
  59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
  60. </ul>
  61. </div>
  62. </div>
  63. </nav>
  64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  66. <a href="../../../../index.html">SuperGradients</a>
  67. </nav>
  68. <div class="wy-nav-content">
  69. <div class="rst-content">
  70. <div role="navigation" aria-label="Page navigation">
  71. <ul class="wy-breadcrumbs">
  72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
  74. <li>super_gradients.training.utils.segmentation_utils</li>
  75. <li class="wy-breadcrumbs-aside">
  76. </li>
  77. </ul>
  78. <hr/>
  79. </div>
  80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  81. <div itemprop="articleBody">
  82. <h1>Source code for super_gradients.training.utils.segmentation_utils</h1><div class="highlight"><pre>
  83. <span></span><span class="kn">import</span> <span class="nn">os</span>
  84. <span class="kn">import</span> <span class="nn">cv2</span>
  85. <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
  86. <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Union</span><span class="p">,</span> <span class="n">Callable</span>
  87. <span class="kn">import</span> <span class="nn">torch</span>
  88. <span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="nn">F</span>
  89. <span class="kn">from</span> <span class="nn">torchvision.utils</span> <span class="kn">import</span> <span class="n">draw_segmentation_masks</span>
  90. <span class="c1"># FIXME: REFACTOR AUGMENTATIONS, CONSIDER USING A MORE EFFICIENT LIBRARIES SUCH AS, IMGAUG, DALI ETC.</span>
  91. <span class="kn">from</span> <span class="nn">super_gradients.training</span> <span class="kn">import</span> <span class="n">utils</span> <span class="k">as</span> <span class="n">core_utils</span>
  92. <div class="viewcode-block" id="coco_sub_classes_inclusion_tuples_list"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.segmentation_utils.coco_sub_classes_inclusion_tuples_list">[docs]</a><span class="k">def</span> <span class="nf">coco_sub_classes_inclusion_tuples_list</span><span class="p">():</span>
  93. <span class="k">return</span> <span class="p">[(</span><span class="mi">0</span><span class="p">,</span> <span class="s1">&#39;background&#39;</span><span class="p">),</span> <span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="s1">&#39;airplane&#39;</span><span class="p">),</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="s1">&#39;bicycle&#39;</span><span class="p">),</span> <span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="s1">&#39;bird&#39;</span><span class="p">),</span>
  94. <span class="p">(</span><span class="mi">9</span><span class="p">,</span> <span class="s1">&#39;boat&#39;</span><span class="p">),</span>
  95. <span class="p">(</span><span class="mi">44</span><span class="p">,</span> <span class="s1">&#39;bottle&#39;</span><span class="p">),</span> <span class="p">(</span><span class="mi">6</span><span class="p">,</span> <span class="s1">&#39;bus&#39;</span><span class="p">),</span> <span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="s1">&#39;car&#39;</span><span class="p">),</span> <span class="p">(</span><span class="mi">17</span><span class="p">,</span> <span class="s1">&#39;cat&#39;</span><span class="p">),</span> <span class="p">(</span><span class="mi">62</span><span class="p">,</span> <span class="s1">&#39;chair&#39;</span><span class="p">),</span>
  96. <span class="p">(</span><span class="mi">21</span><span class="p">,</span> <span class="s1">&#39;cow&#39;</span><span class="p">),</span>
  97. <span class="p">(</span><span class="mi">67</span><span class="p">,</span> <span class="s1">&#39;dining table&#39;</span><span class="p">),</span> <span class="p">(</span><span class="mi">18</span><span class="p">,</span> <span class="s1">&#39;dog&#39;</span><span class="p">),</span> <span class="p">(</span><span class="mi">19</span><span class="p">,</span> <span class="s1">&#39;horse&#39;</span><span class="p">),</span> <span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="s1">&#39;motorcycle&#39;</span><span class="p">),</span>
  98. <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="s1">&#39;person&#39;</span><span class="p">),</span>
  99. <span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="s1">&#39;potted plant&#39;</span><span class="p">),</span> <span class="p">(</span><span class="mi">20</span><span class="p">,</span> <span class="s1">&#39;sheep&#39;</span><span class="p">),</span> <span class="p">(</span><span class="mi">63</span><span class="p">,</span> <span class="s1">&#39;couch&#39;</span><span class="p">),</span> <span class="p">(</span><span class="mi">7</span><span class="p">,</span> <span class="s1">&#39;train&#39;</span><span class="p">),</span>
  100. <span class="p">(</span><span class="mi">72</span><span class="p">,</span> <span class="s1">&#39;tv&#39;</span><span class="p">)]</span></div>
  101. <div class="viewcode-block" id="to_one_hot"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.segmentation_utils.to_one_hot">[docs]</a><span class="k">def</span> <span class="nf">to_one_hot</span><span class="p">(</span><span class="n">target</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">ignore_index</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  102. <span class="sd">&quot;&quot;&quot;</span>
  103. <span class="sd"> Target label to one_hot tensor. labels and ignore_index must be consecutive numbers.</span>
  104. <span class="sd"> :param target: Class labels long tensor, with shape [N, H, W]</span>
  105. <span class="sd"> :param num_classes: num of classes in datasets excluding ignore label, this is the output channels of the one hot</span>
  106. <span class="sd"> result.</span>
  107. <span class="sd"> :return: one hot tensor with shape [N, num_classes, H, W]</span>
  108. <span class="sd"> &quot;&quot;&quot;</span>
  109. <span class="n">num_classes</span> <span class="o">=</span> <span class="n">num_classes</span> <span class="k">if</span> <span class="n">ignore_index</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">num_classes</span> <span class="o">+</span> <span class="mi">1</span>
  110. <span class="n">one_hot</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">one_hot</span><span class="p">(</span><span class="n">target</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">)</span><span class="o">.</span><span class="n">permute</span><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">))</span>
  111. <span class="k">if</span> <span class="n">ignore_index</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  112. <span class="c1"># remove ignore_index channel</span>
  113. <span class="n">one_hot</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">one_hot</span><span class="p">[:,</span> <span class="p">:</span><span class="n">ignore_index</span><span class="p">],</span> <span class="n">one_hot</span><span class="p">[:,</span> <span class="n">ignore_index</span> <span class="o">+</span> <span class="mi">1</span><span class="p">:]],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
  114. <span class="k">return</span> <span class="n">one_hot</span></div>
  115. <div class="viewcode-block" id="reverse_imagenet_preprocessing"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.segmentation_utils.reverse_imagenet_preprocessing">[docs]</a><span class="k">def</span> <span class="nf">reverse_imagenet_preprocessing</span><span class="p">(</span><span class="n">im_tensor</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">:</span>
  116. <span class="sd">&quot;&quot;&quot;</span>
  117. <span class="sd"> :param im_tensor: images in a batch after preprocessing for inference, RGB, (B, C, H, W)</span>
  118. <span class="sd"> :return: images in a batch in cv2 format, BGR, (B, H, W, C)</span>
  119. <span class="sd"> &quot;&quot;&quot;</span>
  120. <span class="n">im_np</span> <span class="o">=</span> <span class="n">im_tensor</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
  121. <span class="n">im_np</span> <span class="o">=</span> <span class="n">im_np</span><span class="p">[:,</span> <span class="p">::</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:]</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
  122. <span class="n">im_np</span> <span class="o">*=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([[[</span><span class="mf">.229</span><span class="p">,</span> <span class="mf">.224</span><span class="p">,</span> <span class="mf">.225</span><span class="p">][::</span><span class="o">-</span><span class="mi">1</span><span class="p">]]])</span>
  123. <span class="n">im_np</span> <span class="o">+=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([[[</span><span class="mf">.485</span><span class="p">,</span> <span class="mf">.456</span><span class="p">,</span> <span class="mf">.406</span><span class="p">][::</span><span class="o">-</span><span class="mi">1</span><span class="p">]]])</span>
  124. <span class="n">im_np</span> <span class="o">*=</span> <span class="mf">255.</span>
  125. <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">ascontiguousarray</span><span class="p">(</span><span class="n">im_np</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span></div>
  126. <div class="viewcode-block" id="BinarySegmentationVisualization"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.segmentation_utils.BinarySegmentationVisualization">[docs]</a><span class="k">class</span> <span class="nc">BinarySegmentationVisualization</span><span class="p">:</span>
  127. <span class="nd">@staticmethod</span>
  128. <span class="k">def</span> <span class="nf">_visualize_image</span><span class="p">(</span><span class="n">image_np</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">pred_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">target_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
  129. <span class="n">image_scale</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">checkpoint_dir</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">image_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
  130. <span class="n">pred_mask</span> <span class="o">=</span> <span class="n">pred_mask</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
  131. <span class="n">image_np</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">moveaxis</span><span class="p">(</span><span class="n">image_np</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">))</span>
  132. <span class="n">pred_mask</span> <span class="o">=</span> <span class="n">pred_mask</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">newaxis</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:]</span> <span class="o">&gt;</span> <span class="mf">0.5</span>
  133. <span class="n">target_mask</span> <span class="o">=</span> <span class="n">target_mask</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">newaxis</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:]</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">bool</span><span class="p">)</span>
  134. <span class="n">tp_mask</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">logical_and</span><span class="p">(</span><span class="n">pred_mask</span><span class="p">,</span> <span class="n">target_mask</span><span class="p">)</span>
  135. <span class="n">fp_mask</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">logical_and</span><span class="p">(</span><span class="n">pred_mask</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">logical_not</span><span class="p">(</span><span class="n">target_mask</span><span class="p">))</span>
  136. <span class="n">fn_mask</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">logical_and</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">logical_not</span><span class="p">(</span><span class="n">pred_mask</span><span class="p">),</span> <span class="n">target_mask</span><span class="p">)</span>
  137. <span class="n">overlay</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">tp_mask</span><span class="p">,</span> <span class="n">fp_mask</span><span class="p">,</span> <span class="n">fn_mask</span><span class="p">]))</span>
  138. <span class="c1"># SWITCH BETWEEN BLUE AND RED IF WE SAVE THE IMAGE ON THE DISC AS OTHERWISE WE CHANGE CHANNEL ORDERING</span>
  139. <span class="n">colors</span> <span class="o">=</span> <span class="p">[</span><span class="s1">&#39;green&#39;</span><span class="p">,</span> <span class="s1">&#39;red&#39;</span><span class="p">,</span> <span class="s1">&#39;blue&#39;</span><span class="p">]</span>
  140. <span class="n">res_image</span> <span class="o">=</span> <span class="n">draw_segmentation_masks</span><span class="p">(</span><span class="n">image_np</span><span class="p">,</span> <span class="n">overlay</span><span class="p">,</span> <span class="n">colors</span><span class="o">=</span><span class="n">colors</span><span class="p">)</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
  141. <span class="n">res_image</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">res_image</span><span class="p">[</span><span class="n">ch</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:,</span> <span class="n">np</span><span class="o">.</span><span class="n">newaxis</span><span class="p">]</span> <span class="k">for</span> <span class="n">ch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">3</span><span class="p">)],</span> <span class="mi">2</span><span class="p">)</span>
  142. <span class="n">res_image</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span><span class="n">res_image</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">),</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="n">fx</span><span class="o">=</span><span class="n">image_scale</span><span class="p">,</span> <span class="n">fy</span><span class="o">=</span><span class="n">image_scale</span><span class="p">,</span>
  143. <span class="n">interpolation</span><span class="o">=</span><span class="n">cv2</span><span class="o">.</span><span class="n">INTER_NEAREST</span><span class="p">)</span>
  144. <span class="k">if</span> <span class="n">checkpoint_dir</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
  145. <span class="k">return</span> <span class="n">res_image</span>
  146. <span class="k">else</span><span class="p">:</span>
  147. <span class="n">cv2</span><span class="o">.</span><span class="n">imwrite</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">checkpoint_dir</span><span class="p">,</span> <span class="nb">str</span><span class="p">(</span><span class="n">image_name</span><span class="p">)</span> <span class="o">+</span> <span class="s1">&#39;.jpg&#39;</span><span class="p">),</span> <span class="n">res_image</span><span class="p">)</span>
  148. <div class="viewcode-block" id="BinarySegmentationVisualization.visualize_batch"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.segmentation_utils.BinarySegmentationVisualization.visualize_batch">[docs]</a> <span class="nd">@staticmethod</span>
  149. <span class="k">def</span> <span class="nf">visualize_batch</span><span class="p">(</span><span class="n">image_tensor</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">pred_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">target_mask</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
  150. <span class="n">batch_name</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">str</span><span class="p">],</span> <span class="n">checkpoint_dir</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
  151. <span class="n">undo_preprocessing_func</span><span class="p">:</span> <span class="n">Callable</span><span class="p">[[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">]</span> <span class="o">=</span> <span class="n">reverse_imagenet_preprocessing</span><span class="p">,</span>
  152. <span class="n">image_scale</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.</span><span class="p">):</span>
  153. <span class="sd">&quot;&quot;&quot;</span>
  154. <span class="sd"> A helper function to visualize detections predicted by a network:</span>
  155. <span class="sd"> saves images into a given path with a name that is {batch_name}_{imade_idx_in_the_batch}.jpg, one batch per call.</span>
  156. <span class="sd"> Colors are generated on the fly: uniformly sampled from color wheel to support all given classes.</span>
  157. <span class="sd"> :param image_tensor: rgb images, (B, H, W, 3)</span>
  158. <span class="sd"> :param pred_boxes: boxes after NMS for each image in a batch, each (Num_boxes, 6),</span>
  159. <span class="sd"> values on dim 1 are: x1, y1, x2, y2, confidence, class</span>
  160. <span class="sd"> :param target_boxes: (Num_targets, 6), values on dim 1 are: image id in a batch, class, x y w h</span>
  161. <span class="sd"> (coordinates scaled to [0, 1])</span>
  162. <span class="sd"> :param batch_name: id of the current batch to use for image naming</span>
  163. <span class="sd"> :param checkpoint_dir: a path where images with boxes will be saved. if None, the result images will</span>
  164. <span class="sd"> be returns as a list of numpy image arrays</span>
  165. <span class="sd"> :param undo_preprocessing_func: a function to convert preprocessed images tensor into a batch of cv2-like images</span>
  166. <span class="sd"> :param image_scale: scale factor for output image</span>
  167. <span class="sd"> &quot;&quot;&quot;</span>
  168. <span class="n">image_np</span> <span class="o">=</span> <span class="n">undo_preprocessing_func</span><span class="p">(</span><span class="n">image_tensor</span><span class="o">.</span><span class="n">detach</span><span class="p">())</span>
  169. <span class="n">pred_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">pred_mask</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:])</span> <span class="c1"># comment out</span>
  170. <span class="n">out_images</span> <span class="o">=</span> <span class="p">[]</span>
  171. <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">image_np</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]):</span>
  172. <span class="n">preds</span> <span class="o">=</span> <span class="n">pred_mask</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
  173. <span class="n">targets</span> <span class="o">=</span> <span class="n">target_mask</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
  174. <span class="n">image_name</span> <span class="o">=</span> <span class="s1">&#39;_&#39;</span><span class="o">.</span><span class="n">join</span><span class="p">([</span><span class="nb">str</span><span class="p">(</span><span class="n">batch_name</span><span class="p">),</span> <span class="nb">str</span><span class="p">(</span><span class="n">i</span><span class="p">)])</span>
  175. <span class="n">res_image</span> <span class="o">=</span> <span class="n">BinarySegmentationVisualization</span><span class="o">.</span><span class="n">_visualize_image</span><span class="p">(</span><span class="n">image_np</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">preds</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">image_scale</span><span class="p">,</span>
  176. <span class="n">checkpoint_dir</span><span class="p">,</span> <span class="n">image_name</span><span class="p">)</span>
  177. <span class="k">if</span> <span class="n">res_image</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  178. <span class="n">out_images</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">res_image</span><span class="p">)</span>
  179. <span class="k">return</span> <span class="n">out_images</span></div></div>
  180. <div class="viewcode-block" id="visualize_batches"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.segmentation_utils.visualize_batches">[docs]</a><span class="k">def</span> <span class="nf">visualize_batches</span><span class="p">(</span><span class="n">dataloader</span><span class="p">,</span> <span class="n">module</span><span class="p">,</span> <span class="n">visualization_path</span><span class="p">,</span> <span class="n">num_batches</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">undo_preprocessing_func</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  181. <span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="n">visualization_path</span><span class="p">,</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  182. <span class="k">for</span> <span class="n">batch_i</span><span class="p">,</span> <span class="p">(</span><span class="n">imgs</span><span class="p">,</span> <span class="n">targets</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">dataloader</span><span class="p">):</span>
  183. <span class="k">if</span> <span class="n">batch_i</span> <span class="o">==</span> <span class="n">num_batches</span><span class="p">:</span>
  184. <span class="k">return</span>
  185. <span class="n">imgs</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">tensor_container_to_device</span><span class="p">(</span><span class="n">imgs</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s1">&#39;cuda:0&#39;</span><span class="p">))</span>
  186. <span class="n">targets</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">tensor_container_to_device</span><span class="p">(</span><span class="n">targets</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s1">&#39;cuda:0&#39;</span><span class="p">))</span>
  187. <span class="n">pred_mask</span> <span class="o">=</span> <span class="n">module</span><span class="p">(</span><span class="n">imgs</span><span class="p">)</span>
  188. <span class="c1"># Visualize the batch</span>
  189. <span class="k">if</span> <span class="n">undo_preprocessing_func</span><span class="p">:</span>
  190. <span class="n">BinarySegmentationVisualization</span><span class="o">.</span><span class="n">visualize_batch</span><span class="p">(</span><span class="n">imgs</span><span class="p">,</span> <span class="n">pred_mask</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">batch_i</span><span class="p">,</span> <span class="n">visualization_path</span><span class="p">,</span>
  191. <span class="n">undo_preprocessing_func</span><span class="o">=</span><span class="n">undo_preprocessing_func</span><span class="p">)</span>
  192. <span class="k">else</span><span class="p">:</span>
  193. <span class="n">BinarySegmentationVisualization</span><span class="o">.</span><span class="n">visualize_batch</span><span class="p">(</span><span class="n">imgs</span><span class="p">,</span> <span class="n">pred_mask</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">batch_i</span><span class="p">,</span> <span class="n">visualization_path</span><span class="p">)</span></div>
  194. <div class="viewcode-block" id="one_hot_to_binary_edge"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.segmentation_utils.one_hot_to_binary_edge">[docs]</a><span class="k">def</span> <span class="nf">one_hot_to_binary_edge</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
  195. <span class="n">kernel_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
  196. <span class="n">flatten_channels</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span>
  197. <span class="sd">&quot;&quot;&quot;</span>
  198. <span class="sd"> Utils function to create edge feature maps.</span>
  199. <span class="sd"> :param x: input tensor, must be one_hot tensor with shape [B, C, H, W]</span>
  200. <span class="sd"> :param kernel_size: kernel size of dilation erosion convolutions. The result edge widths depends on this argument as</span>
  201. <span class="sd"> follows: `edge_width = kernel - 1`</span>
  202. <span class="sd"> :param flatten_channels: Whether to apply logical_or across channels dimension, if at least one pixel class is</span>
  203. <span class="sd"> considered as edge pixel flatten value is 1. If set as `False` the output tensor shape is [B, C, H, W], else</span>
  204. <span class="sd"> [B, 1, H, W]. Default is `True`.</span>
  205. <span class="sd"> :return: one_hot edge torch.Tensor.</span>
  206. <span class="sd"> &quot;&quot;&quot;</span>
  207. <span class="k">if</span> <span class="n">kernel_size</span> <span class="o">&lt;</span> <span class="mi">0</span> <span class="ow">or</span> <span class="n">kernel_size</span> <span class="o">%</span> <span class="mi">2</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
  208. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;kernel size must be an odd positive values, such as [1, 3, 5, ..], found: </span><span class="si">{</span><span class="n">kernel_size</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
  209. <span class="n">_kernel</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="mi">1</span><span class="p">,</span> <span class="n">kernel_size</span><span class="p">,</span> <span class="n">kernel_size</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
  210. <span class="n">padding</span> <span class="o">=</span> <span class="p">(</span><span class="n">kernel_size</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">//</span> <span class="mi">2</span>
  211. <span class="c1"># Use replicate padding to prevent class shifting and edge formation at the image boundaries.</span>
  212. <span class="n">padded_x</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">pad</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">float</span><span class="p">(),</span> <span class="n">mode</span><span class="o">=</span><span class="s2">&quot;replicate&quot;</span><span class="p">,</span> <span class="n">pad</span><span class="o">=</span><span class="p">[</span><span class="n">padding</span><span class="p">]</span> <span class="o">*</span> <span class="mi">4</span><span class="p">)</span>
  213. <span class="c1"># The binary edges feature map is created by subtracting dilated features from erosed features.</span>
  214. <span class="c1"># First the positive one value masks are expanded (dilation) by applying a sliding window filter of one values.</span>
  215. <span class="c1"># The resulted output is then clamped to binary format to [0, 1], this way the one-hot boundaries are expanded by</span>
  216. <span class="c1"># (kernel_size - 1) / 2.</span>
  217. <span class="n">dilation</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span>
  218. <span class="n">F</span><span class="o">.</span><span class="n">conv2d</span><span class="p">(</span><span class="n">padded_x</span><span class="p">,</span> <span class="n">_kernel</span><span class="p">,</span> <span class="n">groups</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">)),</span>
  219. <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span>
  220. <span class="p">)</span>
  221. <span class="c1"># Similar to dilation, erosion (can be seen as inverse of dilation) is applied to contract the one-hot features by</span>
  222. <span class="c1"># applying a dilation operation on the inverse of the one-hot features.</span>
  223. <span class="n">erosion</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">torch</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span>
  224. <span class="n">F</span><span class="o">.</span><span class="n">conv2d</span><span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">padded_x</span><span class="p">,</span> <span class="n">_kernel</span><span class="p">,</span> <span class="n">groups</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">)),</span>
  225. <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span>
  226. <span class="p">)</span>
  227. <span class="c1"># Finally the edge features are the result of subtracting dilation by erosion.</span>
  228. <span class="c1"># i.e for a simple 1D one-hot input: [0, 0, 0, 1, 1, 1, 0, 0, 0], using sliding kernel with size 3: [1, 1, 1]</span>
  229. <span class="c1"># Dilated features: [0, 0, 1, 1, 1, 1, 1, 0, 0]</span>
  230. <span class="c1"># Erosed inverse features: [0, 0, 0, 0, 1, 0, 0, 0, 0]</span>
  231. <span class="c1"># Edge features: dilation - erosion: [0, 0, 1, 1, 0, 1, 1, 0, 0]</span>
  232. <span class="n">edge</span> <span class="o">=</span> <span class="n">dilation</span> <span class="o">-</span> <span class="n">erosion</span>
  233. <span class="k">if</span> <span class="n">flatten_channels</span><span class="p">:</span>
  234. <span class="c1"># use max operator across channels. Equivalent to logical or for input with binary values [0, 1].</span>
  235. <span class="n">edge</span> <span class="o">=</span> <span class="n">edge</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="kc">True</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
  236. <span class="k">return</span> <span class="n">edge</span></div>
  237. <div class="viewcode-block" id="target_to_binary_edge"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.segmentation_utils.target_to_binary_edge">[docs]</a><span class="k">def</span> <span class="nf">target_to_binary_edge</span><span class="p">(</span><span class="n">target</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span>
  238. <span class="n">num_classes</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
  239. <span class="n">kernel_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
  240. <span class="n">ignore_index</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
  241. <span class="n">flatten_channels</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span>
  242. <span class="sd">&quot;&quot;&quot;</span>
  243. <span class="sd"> Utils function to create edge feature maps from target.</span>
  244. <span class="sd"> :param target: Class labels long tensor, with shape [N, H, W]</span>
  245. <span class="sd"> :param num_classes: num of classes in datasets excluding ignore label, this is the output channels of the one hot</span>
  246. <span class="sd"> result.</span>
  247. <span class="sd"> :param kernel_size: kernel size of dilation erosion convolutions. The result edge widths depends on this argument as</span>
  248. <span class="sd"> follows: `edge_width = kernel - 1`</span>
  249. <span class="sd"> :param flatten_channels: Whether to apply logical or across channels dimension, if at least one pixel class is</span>
  250. <span class="sd"> considered as edge pixel flatten value is 1. If set as `False` the output tensor shape is [B, C, H, W], else</span>
  251. <span class="sd"> [B, 1, H, W]. Default is `True`.</span>
  252. <span class="sd"> :return: one_hot edge torch.Tensor.</span>
  253. <span class="sd"> &quot;&quot;&quot;</span>
  254. <span class="n">one_hot</span> <span class="o">=</span> <span class="n">to_one_hot</span><span class="p">(</span><span class="n">target</span><span class="p">,</span> <span class="n">num_classes</span><span class="o">=</span><span class="n">num_classes</span><span class="p">,</span> <span class="n">ignore_index</span><span class="o">=</span><span class="n">ignore_index</span><span class="p">)</span>
  255. <span class="k">return</span> <span class="n">one_hot_to_binary_edge</span><span class="p">(</span><span class="n">one_hot</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="n">kernel_size</span><span class="p">,</span> <span class="n">flatten_channels</span><span class="o">=</span><span class="n">flatten_channels</span><span class="p">)</span></div>
  256. </pre></div>
  257. </div>
  258. </div>
  259. <footer>
  260. <hr/>
  261. <div role="contentinfo">
  262. <p>&#169; Copyright 2021, SuperGradients team.</p>
  263. </div>
  264. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  265. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  266. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  267. </footer>
  268. </div>
  269. </div>
  270. </section>
  271. </div>
  272. <script>
  273. jQuery(function () {
  274. SphinxRtdTheme.Navigation.enable(true);
  275. });
  276. </script>
  277. </body>
  278. </html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.training.utils.sg_model_utils &mdash; SuperGradients 1.0 documentation</title>
  7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
  10. <!--[if lt IE 9]>
  11. <script src="../../../../_static/js/html5shiv.min.js"></script>
  12. <![endif]-->
  13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
  14. <script src="../../../../_static/jquery.js"></script>
  15. <script src="../../../../_static/underscore.js"></script>
  16. <script src="../../../../_static/doctools.js"></script>
  17. <script src="../../../../_static/js/theme.js"></script>
  18. <link rel="index" title="Index" href="../../../../genindex.html" />
  19. <link rel="search" title="Search" href="../../../../search.html" />
  20. </head>
  21. <body class="wy-body-for-nav">
  22. <div class="wy-grid-for-nav">
  23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  24. <div class="wy-side-scroll">
  25. <div class="wy-side-nav-search" >
  26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
  27. </a>
  28. <div role="search">
  29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
  30. <input type="text" name="q" placeholder="Search docs" />
  31. <input type="hidden" name="check_keywords" value="yes" />
  32. <input type="hidden" name="area" value="default" />
  33. </form>
  34. </div>
  35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
  37. <ul>
  38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
  39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
  40. </ul>
  41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
  42. <ul>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
  45. </ul>
  46. <p class="caption"><span class="caption-text">User Guide</span></p>
  47. <ul>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
  57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
  58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
  59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
  60. </ul>
  61. </div>
  62. </div>
  63. </nav>
  64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  66. <a href="../../../../index.html">SuperGradients</a>
  67. </nav>
  68. <div class="wy-nav-content">
  69. <div class="rst-content">
  70. <div role="navigation" aria-label="Page navigation">
  71. <ul class="wy-breadcrumbs">
  72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
  74. <li>super_gradients.training.utils.sg_model_utils</li>
  75. <li class="wy-breadcrumbs-aside">
  76. </li>
  77. </ul>
  78. <hr/>
  79. </div>
  80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  81. <div itemprop="articleBody">
  82. <h1>Source code for super_gradients.training.utils.sg_model_utils</h1><div class="highlight"><pre>
  83. <span></span><span class="kn">import</span> <span class="nn">os</span>
  84. <span class="kn">import</span> <span class="nn">sys</span>
  85. <span class="kn">import</span> <span class="nn">socket</span>
  86. <span class="kn">import</span> <span class="nn">time</span>
  87. <span class="kn">from</span> <span class="nn">dataclasses</span> <span class="kn">import</span> <span class="n">dataclass</span>
  88. <span class="kn">from</span> <span class="nn">multiprocessing</span> <span class="kn">import</span> <span class="n">Process</span>
  89. <span class="kn">from</span> <span class="nn">pathlib</span> <span class="kn">import</span> <span class="n">Path</span>
  90. <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">Dict</span>
  91. <span class="kn">import</span> <span class="nn">random</span>
  92. <span class="kn">from</span> <span class="nn">treelib</span> <span class="kn">import</span> <span class="n">Tree</span>
  93. <span class="kn">from</span> <span class="nn">termcolor</span> <span class="kn">import</span> <span class="n">colored</span>
  94. <span class="kn">import</span> <span class="nn">torch</span>
  95. <span class="kn">from</span> <span class="nn">torch.utils.tensorboard</span> <span class="kn">import</span> <span class="n">SummaryWriter</span>
  96. <span class="kn">from</span> <span class="nn">super_gradients.training.exceptions.dataset_exceptions</span> <span class="kn">import</span> <span class="n">UnsupportedBatchItemsFormat</span>
  97. <span class="c1"># TODO: These utils should move to sg_model package as internal (private) helper functions</span>
  98. <span class="n">IS_BETTER_COLOR</span> <span class="o">=</span> <span class="p">{</span><span class="kc">True</span><span class="p">:</span> <span class="s2">&quot;green&quot;</span><span class="p">,</span> <span class="kc">False</span><span class="p">:</span> <span class="s2">&quot;red&quot;</span><span class="p">}</span>
  99. <span class="n">IS_GREATER_SYMBOLS</span> <span class="o">=</span> <span class="p">{</span><span class="kc">True</span><span class="p">:</span> <span class="s2">&quot;↗&quot;</span><span class="p">,</span> <span class="kc">False</span><span class="p">:</span> <span class="s2">&quot;↘&quot;</span><span class="p">}</span>
  100. <div class="viewcode-block" id="MonitoredValue"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.sg_model_utils.MonitoredValue">[docs]</a><span class="nd">@dataclass</span>
  101. <span class="k">class</span> <span class="nc">MonitoredValue</span><span class="p">:</span>
  102. <span class="sd">&quot;&quot;&quot;Store a value and some indicators relative to its past iterations.</span>
  103. <span class="sd"> The value can be a metric/loss, and the iteration can be epochs/batch.</span>
  104. <span class="sd"> &quot;&quot;&quot;</span>
  105. <span class="n">name</span><span class="p">:</span> <span class="nb">str</span>
  106. <span class="n">greater_is_better</span><span class="p">:</span> <span class="nb">bool</span>
  107. <span class="n">current</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="kc">None</span>
  108. <span class="n">previous</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="kc">None</span>
  109. <span class="n">best</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="kc">None</span>
  110. <span class="n">change_from_previous</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="kc">None</span>
  111. <span class="n">change_from_best</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="kc">None</span>
  112. <span class="nd">@property</span>
  113. <span class="k">def</span> <span class="nf">is_better_than_previous</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  114. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">greater_is_better</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">change_from_best</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
  115. <span class="k">return</span> <span class="kc">None</span>
  116. <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">greater_is_better</span><span class="p">:</span>
  117. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">change_from_previous</span> <span class="o">&gt;=</span> <span class="mi">0</span>
  118. <span class="k">else</span><span class="p">:</span>
  119. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">change_from_previous</span> <span class="o">&lt;</span> <span class="mi">0</span>
  120. <span class="nd">@property</span>
  121. <span class="k">def</span> <span class="nf">is_best_value</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  122. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">greater_is_better</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">change_from_best</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
  123. <span class="k">return</span> <span class="kc">None</span>
  124. <span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">greater_is_better</span><span class="p">:</span>
  125. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">change_from_best</span> <span class="o">&gt;=</span> <span class="mi">0</span>
  126. <span class="k">else</span><span class="p">:</span>
  127. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">change_from_best</span> <span class="o">&lt;</span> <span class="mi">0</span></div>
  128. <div class="viewcode-block" id="update_monitored_value"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.sg_model_utils.update_monitored_value">[docs]</a><span class="k">def</span> <span class="nf">update_monitored_value</span><span class="p">(</span><span class="n">previous_monitored_value</span><span class="p">:</span> <span class="n">MonitoredValue</span><span class="p">,</span> <span class="n">new_value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">MonitoredValue</span><span class="p">:</span>
  129. <span class="sd">&quot;&quot;&quot;Update the given ValueToMonitor object (could be a loss or a metric) with the new value</span>
  130. <span class="sd"> :param previous_monitored_value: The stats about the value that is monitored throughout epochs.</span>
  131. <span class="sd"> :param new_value: The value of the current epoch that will be used to update previous_monitored_value</span>
  132. <span class="sd"> :return:</span>
  133. <span class="sd"> &quot;&quot;&quot;</span>
  134. <span class="n">previous_value</span><span class="p">,</span> <span class="n">previous_best_value</span> <span class="o">=</span> <span class="n">previous_monitored_value</span><span class="o">.</span><span class="n">current</span><span class="p">,</span> <span class="n">previous_monitored_value</span><span class="o">.</span><span class="n">best</span>
  135. <span class="n">name</span><span class="p">,</span> <span class="n">greater_is_better</span> <span class="o">=</span> <span class="n">previous_monitored_value</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="n">previous_monitored_value</span><span class="o">.</span><span class="n">greater_is_better</span>
  136. <span class="k">if</span> <span class="n">previous_best_value</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
  137. <span class="n">previous_best_value</span> <span class="o">=</span> <span class="n">previous_value</span>
  138. <span class="k">elif</span> <span class="n">greater_is_better</span><span class="p">:</span>
  139. <span class="n">previous_best_value</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">previous_value</span><span class="p">,</span> <span class="n">previous_best_value</span><span class="p">)</span>
  140. <span class="k">else</span><span class="p">:</span>
  141. <span class="n">previous_best_value</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">previous_value</span><span class="p">,</span> <span class="n">previous_best_value</span><span class="p">)</span>
  142. <span class="k">if</span> <span class="n">previous_value</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
  143. <span class="n">change_from_previous</span> <span class="o">=</span> <span class="kc">None</span>
  144. <span class="n">change_from_best</span> <span class="o">=</span> <span class="kc">None</span>
  145. <span class="k">else</span><span class="p">:</span>
  146. <span class="n">change_from_previous</span> <span class="o">=</span> <span class="n">new_value</span> <span class="o">-</span> <span class="n">previous_value</span>
  147. <span class="n">change_from_best</span> <span class="o">=</span> <span class="n">new_value</span> <span class="o">-</span> <span class="n">previous_best_value</span>
  148. <span class="k">return</span> <span class="n">MonitoredValue</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">,</span> <span class="n">current</span><span class="o">=</span><span class="n">new_value</span><span class="p">,</span> <span class="n">previous</span><span class="o">=</span><span class="n">previous_value</span><span class="p">,</span> <span class="n">best</span><span class="o">=</span><span class="n">previous_best_value</span><span class="p">,</span>
  149. <span class="n">change_from_previous</span><span class="o">=</span><span class="n">change_from_previous</span><span class="p">,</span> <span class="n">change_from_best</span><span class="o">=</span><span class="n">change_from_best</span><span class="p">,</span>
  150. <span class="n">greater_is_better</span><span class="o">=</span><span class="n">greater_is_better</span><span class="p">)</span></div>
  151. <div class="viewcode-block" id="update_monitored_values_dict"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.sg_model_utils.update_monitored_values_dict">[docs]</a><span class="k">def</span> <span class="nf">update_monitored_values_dict</span><span class="p">(</span><span class="n">monitored_values_dict</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">MonitoredValue</span><span class="p">],</span>
  152. <span class="n">new_values_dict</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="nb">float</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">MonitoredValue</span><span class="p">]:</span>
  153. <span class="sd">&quot;&quot;&quot;Update the given ValueToMonitor object (could be a loss or a metric) with the new value</span>
  154. <span class="sd"> :param monitored_values_dict: Dict mapping value names to their stats throughout epochs.</span>
  155. <span class="sd"> :param new_values_dict: Dict mapping value names to their new (i.e. current epoch) value.</span>
  156. <span class="sd"> :return: Updated monitored_values_dict</span>
  157. <span class="sd"> &quot;&quot;&quot;</span>
  158. <span class="k">for</span> <span class="n">monitored_value_name</span> <span class="ow">in</span> <span class="n">monitored_values_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
  159. <span class="n">monitored_values_dict</span><span class="p">[</span><span class="n">monitored_value_name</span><span class="p">]</span> <span class="o">=</span> <span class="n">update_monitored_value</span><span class="p">(</span>
  160. <span class="n">new_value</span><span class="o">=</span><span class="n">new_values_dict</span><span class="p">[</span><span class="n">monitored_value_name</span><span class="p">],</span>
  161. <span class="n">previous_monitored_value</span><span class="o">=</span><span class="n">monitored_values_dict</span><span class="p">[</span><span class="n">monitored_value_name</span><span class="p">],</span>
  162. <span class="p">)</span>
  163. <span class="k">return</span> <span class="n">monitored_values_dict</span></div>
  164. <div class="viewcode-block" id="display_epoch_summary"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.sg_model_utils.display_epoch_summary">[docs]</a><span class="k">def</span> <span class="nf">display_epoch_summary</span><span class="p">(</span><span class="n">epoch</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">n_digits</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
  165. <span class="n">train_monitored_values</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">MonitoredValue</span><span class="p">],</span>
  166. <span class="n">valid_monitored_values</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">MonitoredValue</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
  167. <span class="sd">&quot;&quot;&quot;Display a summary of loss/metric of interest, for a given epoch.</span>
  168. <span class="sd"> :param epoch: the number of epoch.</span>
  169. <span class="sd"> :param n_digits: number of digits to display on screen for float values</span>
  170. <span class="sd"> :param train_monitored_values: mapping of loss/metric with their stats that will be displayed</span>
  171. <span class="sd"> :param valid_monitored_values: mapping of loss/metric with their stats that will be displayed</span>
  172. <span class="sd"> :return:</span>
  173. <span class="sd"> &quot;&quot;&quot;</span>
  174. <span class="k">def</span> <span class="nf">_format_to_str</span><span class="p">(</span><span class="n">val</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
  175. <span class="k">return</span> <span class="nb">str</span><span class="p">(</span><span class="nb">round</span><span class="p">(</span><span class="n">val</span><span class="p">,</span> <span class="n">n_digits</span><span class="p">))</span>
  176. <span class="k">def</span> <span class="nf">_generate_tree</span><span class="p">(</span><span class="n">value_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">monitored_value</span><span class="p">:</span> <span class="n">MonitoredValue</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tree</span><span class="p">:</span>
  177. <span class="sd">&quot;&quot;&quot;Generate a tree that represents the stats of a given loss/metric.&quot;&quot;&quot;</span>
  178. <span class="n">current</span> <span class="o">=</span> <span class="n">_format_to_str</span><span class="p">(</span><span class="n">monitored_value</span><span class="o">.</span><span class="n">current</span><span class="p">)</span>
  179. <span class="n">root_id</span> <span class="o">=</span> <span class="nb">str</span><span class="p">(</span><span class="nb">hash</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">value_name</span><span class="si">}</span><span class="s2"> = </span><span class="si">{</span><span class="n">current</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">))</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">())</span>
  180. <span class="n">tree</span> <span class="o">=</span> <span class="n">Tree</span><span class="p">()</span>
  181. <span class="n">tree</span><span class="o">.</span><span class="n">create_node</span><span class="p">(</span><span class="n">tag</span><span class="o">=</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">value_name</span><span class="o">.</span><span class="n">capitalize</span><span class="p">()</span><span class="si">}</span><span class="s2"> = </span><span class="si">{</span><span class="n">current</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="n">identifier</span><span class="o">=</span><span class="n">root_id</span><span class="p">)</span>
  182. <span class="k">if</span> <span class="n">monitored_value</span><span class="o">.</span><span class="n">previous</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  183. <span class="n">previous</span> <span class="o">=</span> <span class="n">_format_to_str</span><span class="p">(</span><span class="n">monitored_value</span><span class="o">.</span><span class="n">previous</span><span class="p">)</span>
  184. <span class="n">best</span> <span class="o">=</span> <span class="n">_format_to_str</span><span class="p">(</span><span class="n">monitored_value</span><span class="o">.</span><span class="n">best</span><span class="p">)</span>
  185. <span class="n">change_from_previous</span> <span class="o">=</span> <span class="n">_format_to_str</span><span class="p">(</span><span class="n">monitored_value</span><span class="o">.</span><span class="n">change_from_previous</span><span class="p">)</span>
  186. <span class="n">change_from_best</span> <span class="o">=</span> <span class="n">_format_to_str</span><span class="p">(</span><span class="n">monitored_value</span><span class="o">.</span><span class="n">change_from_best</span><span class="p">)</span>
  187. <span class="n">diff_with_prev_colored</span> <span class="o">=</span> <span class="n">colored</span><span class="p">(</span>
  188. <span class="n">text</span><span class="o">=</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">IS_GREATER_SYMBOLS</span><span class="p">[</span><span class="n">monitored_value</span><span class="o">.</span><span class="n">change_from_previous</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="n">change_from_previous</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span>
  189. <span class="n">color</span><span class="o">=</span><span class="n">IS_BETTER_COLOR</span><span class="p">[</span><span class="n">monitored_value</span><span class="o">.</span><span class="n">is_better_than_previous</span><span class="p">]</span>
  190. <span class="p">)</span>
  191. <span class="n">diff_with_best_colored</span> <span class="o">=</span> <span class="n">colored</span><span class="p">(</span>
  192. <span class="n">text</span><span class="o">=</span><span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">IS_GREATER_SYMBOLS</span><span class="p">[</span><span class="n">monitored_value</span><span class="o">.</span><span class="n">change_from_best</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2"> </span><span class="si">{</span><span class="n">change_from_best</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span>
  193. <span class="n">color</span><span class="o">=</span><span class="n">IS_BETTER_COLOR</span><span class="p">[</span><span class="n">monitored_value</span><span class="o">.</span><span class="n">is_best_value</span><span class="p">]</span>
  194. <span class="p">)</span>
  195. <span class="n">tree</span><span class="o">.</span><span class="n">create_node</span><span class="p">(</span>
  196. <span class="n">tag</span><span class="o">=</span><span class="sa">f</span><span class="s2">&quot;Epoch N-1 = </span><span class="si">{</span><span class="n">previous</span><span class="si">:</span><span class="s2">6</span><span class="si">}</span><span class="s2"> (</span><span class="si">{</span><span class="n">diff_with_prev_colored</span><span class="si">:</span><span class="s2">8</span><span class="si">}</span><span class="s2">)&quot;</span><span class="p">,</span>
  197. <span class="n">identifier</span><span class="o">=</span><span class="sa">f</span><span class="s2">&quot;0_previous_</span><span class="si">{</span><span class="n">root_id</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span>
  198. <span class="n">parent</span><span class="o">=</span><span class="n">root_id</span>
  199. <span class="p">)</span>
  200. <span class="n">tree</span><span class="o">.</span><span class="n">create_node</span><span class="p">(</span>
  201. <span class="n">tag</span><span class="o">=</span><span class="sa">f</span><span class="s2">&quot;Best until now = </span><span class="si">{</span><span class="n">best</span><span class="si">:</span><span class="s2">6</span><span class="si">}</span><span class="s2"> (</span><span class="si">{</span><span class="n">diff_with_best_colored</span><span class="si">:</span><span class="s2">8</span><span class="si">}</span><span class="s2">)&quot;</span><span class="p">,</span>
  202. <span class="n">identifier</span><span class="o">=</span><span class="sa">f</span><span class="s2">&quot;1_best_</span><span class="si">{</span><span class="n">root_id</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span>
  203. <span class="n">parent</span><span class="o">=</span><span class="n">root_id</span>
  204. <span class="p">)</span>
  205. <span class="k">return</span> <span class="n">tree</span>
  206. <span class="n">train_tree</span> <span class="o">=</span> <span class="n">Tree</span><span class="p">()</span>
  207. <span class="n">train_tree</span><span class="o">.</span><span class="n">create_node</span><span class="p">(</span><span class="s2">&quot;Training&quot;</span><span class="p">,</span> <span class="s2">&quot;Training&quot;</span><span class="p">)</span>
  208. <span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">train_monitored_values</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
  209. <span class="n">train_tree</span><span class="o">.</span><span class="n">paste</span><span class="p">(</span><span class="s1">&#39;Training&#39;</span><span class="p">,</span> <span class="n">new_tree</span><span class="o">=</span><span class="n">_generate_tree</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">monitored_value</span><span class="o">=</span><span class="n">value</span><span class="p">))</span>
  210. <span class="n">valid_tree</span> <span class="o">=</span> <span class="n">Tree</span><span class="p">()</span>
  211. <span class="n">valid_tree</span><span class="o">.</span><span class="n">create_node</span><span class="p">(</span><span class="s2">&quot;Validation&quot;</span><span class="p">,</span> <span class="s2">&quot;Validation&quot;</span><span class="p">)</span>
  212. <span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">valid_monitored_values</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
  213. <span class="n">valid_tree</span><span class="o">.</span><span class="n">paste</span><span class="p">(</span><span class="s1">&#39;Validation&#39;</span><span class="p">,</span> <span class="n">new_tree</span><span class="o">=</span><span class="n">_generate_tree</span><span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">monitored_value</span><span class="o">=</span><span class="n">value</span><span class="p">))</span>
  214. <span class="n">summary_tree</span> <span class="o">=</span> <span class="n">Tree</span><span class="p">()</span>
  215. <span class="n">summary_tree</span><span class="o">.</span><span class="n">create_node</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;SUMMARY OF EPOCH </span><span class="si">{</span><span class="n">epoch</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="s2">&quot;Summary&quot;</span><span class="p">)</span>
  216. <span class="n">summary_tree</span><span class="o">.</span><span class="n">paste</span><span class="p">(</span><span class="s2">&quot;Summary&quot;</span><span class="p">,</span> <span class="n">train_tree</span><span class="p">)</span>
  217. <span class="n">summary_tree</span><span class="o">.</span><span class="n">paste</span><span class="p">(</span><span class="s2">&quot;Summary&quot;</span><span class="p">,</span> <span class="n">valid_tree</span><span class="p">)</span>
  218. <span class="n">summary_tree</span><span class="o">.</span><span class="n">show</span><span class="p">()</span></div>
  219. <div class="viewcode-block" id="try_port"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.sg_model_utils.try_port">[docs]</a><span class="k">def</span> <span class="nf">try_port</span><span class="p">(</span><span class="n">port</span><span class="p">):</span>
  220. <span class="sd">&quot;&quot;&quot;</span>
  221. <span class="sd"> try_port - Helper method for tensorboard port binding</span>
  222. <span class="sd"> :param port:</span>
  223. <span class="sd"> :return:</span>
  224. <span class="sd"> &quot;&quot;&quot;</span>
  225. <span class="n">sock</span> <span class="o">=</span> <span class="n">socket</span><span class="o">.</span><span class="n">socket</span><span class="p">(</span><span class="n">socket</span><span class="o">.</span><span class="n">AF_INET</span><span class="p">,</span> <span class="n">socket</span><span class="o">.</span><span class="n">SOCK_STREAM</span><span class="p">)</span>
  226. <span class="n">is_port_available</span> <span class="o">=</span> <span class="kc">False</span>
  227. <span class="k">try</span><span class="p">:</span>
  228. <span class="n">sock</span><span class="o">.</span><span class="n">bind</span><span class="p">((</span><span class="s2">&quot;localhost&quot;</span><span class="p">,</span> <span class="n">port</span><span class="p">))</span>
  229. <span class="n">is_port_available</span> <span class="o">=</span> <span class="kc">True</span>
  230. <span class="k">except</span> <span class="ne">Exception</span> <span class="k">as</span> <span class="n">ex</span><span class="p">:</span>
  231. <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Port &#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">port</span><span class="p">)</span> <span class="o">+</span> <span class="s1">&#39; is in use&#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">ex</span><span class="p">))</span>
  232. <span class="n">sock</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>
  233. <span class="k">return</span> <span class="n">is_port_available</span></div>
  234. <div class="viewcode-block" id="launch_tensorboard_process"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.sg_model_utils.launch_tensorboard_process">[docs]</a><span class="k">def</span> <span class="nf">launch_tensorboard_process</span><span class="p">(</span><span class="n">checkpoints_dir_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">sleep_postpone</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span> <span class="n">port</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Process</span><span class="p">,</span> <span class="nb">int</span><span class="p">]:</span>
  235. <span class="sd">&quot;&quot;&quot;</span>
  236. <span class="sd"> launch_tensorboard_process - Default behavior is to scan all free ports from 6006-6016 and try using them</span>
  237. <span class="sd"> unless port is defined by the user</span>
  238. <span class="sd"> :param checkpoints_dir_path:</span>
  239. <span class="sd"> :param sleep_postpone:</span>
  240. <span class="sd"> :param port:</span>
  241. <span class="sd"> :return: tuple of tb process, port</span>
  242. <span class="sd"> &quot;&quot;&quot;</span>
  243. <span class="n">logdir_path</span> <span class="o">=</span> <span class="nb">str</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">checkpoints_dir_path</span><span class="p">)</span><span class="o">.</span><span class="n">parent</span><span class="o">.</span><span class="n">absolute</span><span class="p">())</span>
  244. <span class="n">tb_cmd</span> <span class="o">=</span> <span class="s1">&#39;tensorboard --logdir=&#39;</span> <span class="o">+</span> <span class="n">logdir_path</span> <span class="o">+</span> <span class="s1">&#39; --bind_all&#39;</span>
  245. <span class="k">if</span> <span class="n">port</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  246. <span class="n">tb_ports</span> <span class="o">=</span> <span class="p">[</span><span class="n">port</span><span class="p">]</span>
  247. <span class="k">else</span><span class="p">:</span>
  248. <span class="n">tb_ports</span> <span class="o">=</span> <span class="nb">range</span><span class="p">(</span><span class="mi">6006</span><span class="p">,</span> <span class="mi">6016</span><span class="p">)</span>
  249. <span class="k">for</span> <span class="n">tb_port</span> <span class="ow">in</span> <span class="n">tb_ports</span><span class="p">:</span>
  250. <span class="k">if</span> <span class="ow">not</span> <span class="n">try_port</span><span class="p">(</span><span class="n">tb_port</span><span class="p">):</span>
  251. <span class="k">continue</span>
  252. <span class="k">else</span><span class="p">:</span>
  253. <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Starting Tensor-Board process on port: &#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">tb_port</span><span class="p">))</span>
  254. <span class="n">tensor_board_process</span> <span class="o">=</span> <span class="n">Process</span><span class="p">(</span><span class="n">target</span><span class="o">=</span><span class="n">os</span><span class="o">.</span><span class="n">system</span><span class="p">,</span> <span class="n">args</span><span class="o">=</span><span class="p">([</span><span class="n">tb_cmd</span> <span class="o">+</span> <span class="s1">&#39; --port=&#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">tb_port</span><span class="p">)]))</span>
  255. <span class="n">tensor_board_process</span><span class="o">.</span><span class="n">daemon</span> <span class="o">=</span> <span class="kc">True</span>
  256. <span class="n">tensor_board_process</span><span class="o">.</span><span class="n">start</span><span class="p">()</span>
  257. <span class="c1"># LET THE TENSORBOARD PROCESS START</span>
  258. <span class="k">if</span> <span class="n">sleep_postpone</span><span class="p">:</span>
  259. <span class="n">time</span><span class="o">.</span><span class="n">sleep</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span>
  260. <span class="k">return</span> <span class="n">tensor_board_process</span><span class="p">,</span> <span class="n">tb_port</span>
  261. <span class="c1"># RETURNING IRRELEVANT VALUES</span>
  262. <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Failed to initialize Tensor-Board process on port: &#39;</span> <span class="o">+</span> <span class="s1">&#39;, &#39;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="nb">map</span><span class="p">(</span><span class="nb">str</span><span class="p">,</span> <span class="n">tb_ports</span><span class="p">)))</span>
  263. <span class="k">return</span> <span class="kc">None</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span></div>
  264. <div class="viewcode-block" id="init_summary_writer"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.sg_model_utils.init_summary_writer">[docs]</a><span class="k">def</span> <span class="nf">init_summary_writer</span><span class="p">(</span><span class="n">tb_dir</span><span class="p">,</span> <span class="n">checkpoint_loaded</span><span class="p">,</span> <span class="n">user_prompt</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
  265. <span class="sd">&quot;&quot;&quot;Remove previous tensorboard files from directory and launch a tensor board process&quot;&quot;&quot;</span>
  266. <span class="c1"># If the training is from scratch, Walk through destination folder and delete existing tensorboard logs</span>
  267. <span class="n">user</span> <span class="o">=</span> <span class="s1">&#39;&#39;</span>
  268. <span class="k">if</span> <span class="ow">not</span> <span class="n">checkpoint_loaded</span><span class="p">:</span>
  269. <span class="k">for</span> <span class="n">filename</span> <span class="ow">in</span> <span class="n">os</span><span class="o">.</span><span class="n">listdir</span><span class="p">(</span><span class="n">tb_dir</span><span class="p">):</span>
  270. <span class="k">if</span> <span class="s1">&#39;events&#39;</span> <span class="ow">in</span> <span class="n">filename</span><span class="p">:</span>
  271. <span class="k">if</span> <span class="ow">not</span> <span class="n">user_prompt</span><span class="p">:</span>
  272. <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;&quot;</span><span class="si">{}</span><span class="s1">&quot; will not be deleted&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">filename</span><span class="p">))</span>
  273. <span class="k">continue</span>
  274. <span class="k">while</span> <span class="kc">True</span><span class="p">:</span>
  275. <span class="c1"># Verify with user before deleting old tensorboard files</span>
  276. <span class="n">user</span> <span class="o">=</span> <span class="nb">input</span><span class="p">(</span><span class="s1">&#39;</span><span class="se">\n</span><span class="s1">OLDER TENSORBOARD FILES EXISTS IN EXPERIMENT FOLDER:</span><span class="se">\n</span><span class="s1">&quot;</span><span class="si">{}</span><span class="s1">&quot;</span><span class="se">\n</span><span class="s1">&#39;</span>
  277. <span class="s1">&#39;DO YOU WANT TO DELETE THEM? [y/n]&#39;</span>
  278. <span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">filename</span><span class="p">))</span> <span class="k">if</span> <span class="p">(</span><span class="n">user</span> <span class="o">!=</span> <span class="s1">&#39;n&#39;</span> <span class="ow">or</span> <span class="n">user</span> <span class="o">!=</span> <span class="s1">&#39;y&#39;</span><span class="p">)</span> <span class="k">else</span> <span class="n">user</span>
  279. <span class="k">if</span> <span class="n">user</span> <span class="o">==</span> <span class="s1">&#39;y&#39;</span><span class="p">:</span>
  280. <span class="n">os</span><span class="o">.</span><span class="n">remove</span><span class="p">(</span><span class="s1">&#39;</span><span class="si">{}</span><span class="s1">/</span><span class="si">{}</span><span class="s1">&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">tb_dir</span><span class="p">,</span> <span class="n">filename</span><span class="p">))</span>
  281. <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;DELETED: </span><span class="si">{}</span><span class="s1">!&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">filename</span><span class="p">))</span>
  282. <span class="k">break</span>
  283. <span class="k">elif</span> <span class="n">user</span> <span class="o">==</span> <span class="s1">&#39;n&#39;</span><span class="p">:</span>
  284. <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;&quot;</span><span class="si">{}</span><span class="s1">&quot; will not be deleted&#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">filename</span><span class="p">))</span>
  285. <span class="k">break</span>
  286. <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Unknown answer...&#39;</span><span class="p">)</span>
  287. <span class="c1"># Launch a tensorboard process</span>
  288. <span class="k">return</span> <span class="n">SummaryWriter</span><span class="p">(</span><span class="n">tb_dir</span><span class="p">)</span></div>
  289. <div class="viewcode-block" id="add_log_to_file"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.sg_model_utils.add_log_to_file">[docs]</a><span class="k">def</span> <span class="nf">add_log_to_file</span><span class="p">(</span><span class="n">filename</span><span class="p">,</span> <span class="n">results_titles_list</span><span class="p">,</span> <span class="n">results_values_list</span><span class="p">,</span> <span class="n">epoch</span><span class="p">,</span> <span class="n">max_epochs</span><span class="p">):</span>
  290. <span class="sd">&quot;&quot;&quot;Add a message to the log file&quot;&quot;&quot;</span>
  291. <span class="c1"># -Note: opening and closing the file every time is in-efficient. It is done for experimental purposes</span>
  292. <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">filename</span><span class="p">,</span> <span class="s1">&#39;a&#39;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
  293. <span class="n">f</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="s1">&#39;</span><span class="se">\n</span><span class="s1">Epoch (</span><span class="si">%d</span><span class="s1">/</span><span class="si">%d</span><span class="s1">) - &#39;</span> <span class="o">%</span> <span class="p">(</span><span class="n">epoch</span><span class="p">,</span> <span class="n">max_epochs</span><span class="p">))</span>
  294. <span class="k">for</span> <span class="n">result_title</span><span class="p">,</span> <span class="n">result_value</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">results_titles_list</span><span class="p">,</span> <span class="n">results_values_list</span><span class="p">):</span>
  295. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">result_value</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
  296. <span class="n">result_value</span> <span class="o">=</span> <span class="n">result_value</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
  297. <span class="n">f</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="n">result_title</span> <span class="o">+</span> <span class="s1">&#39;: &#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">result_value</span><span class="p">)</span> <span class="o">+</span> <span class="s1">&#39;</span><span class="se">\t</span><span class="s1">&#39;</span><span class="p">)</span></div>
  298. <div class="viewcode-block" id="write_training_results"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.sg_model_utils.write_training_results">[docs]</a><span class="k">def</span> <span class="nf">write_training_results</span><span class="p">(</span><span class="n">writer</span><span class="p">,</span> <span class="n">results_titles_list</span><span class="p">,</span> <span class="n">results_values_list</span><span class="p">,</span> <span class="n">epoch</span><span class="p">):</span>
  299. <span class="sd">&quot;&quot;&quot;Stores the training and validation loss and accuracy for current epoch in a tensorboard file&quot;&quot;&quot;</span>
  300. <span class="k">for</span> <span class="n">res_key</span><span class="p">,</span> <span class="n">res_val</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">results_titles_list</span><span class="p">,</span> <span class="n">results_values_list</span><span class="p">):</span>
  301. <span class="c1"># USE ONLY LOWER-CASE LETTERS AND REPLACE SPACES WITH &#39;_&#39; TO AVOID MANY TITLES FOR THE SAME KEY</span>
  302. <span class="n">corrected_res_key</span> <span class="o">=</span> <span class="n">res_key</span><span class="o">.</span><span class="n">lower</span><span class="p">()</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s1">&#39; &#39;</span><span class="p">,</span> <span class="s1">&#39;_&#39;</span><span class="p">)</span>
  303. <span class="n">writer</span><span class="o">.</span><span class="n">add_scalar</span><span class="p">(</span><span class="n">corrected_res_key</span><span class="p">,</span> <span class="n">res_val</span><span class="p">,</span> <span class="n">epoch</span><span class="p">)</span>
  304. <span class="n">writer</span><span class="o">.</span><span class="n">flush</span><span class="p">()</span></div>
  305. <div class="viewcode-block" id="write_hpms"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.sg_model_utils.write_hpms">[docs]</a><span class="k">def</span> <span class="nf">write_hpms</span><span class="p">(</span><span class="n">writer</span><span class="p">,</span> <span class="n">hpmstructs</span><span class="o">=</span><span class="p">[],</span> <span class="n">special_conf</span><span class="o">=</span><span class="p">{}):</span>
  306. <span class="sd">&quot;&quot;&quot;Stores the training and dataset hyper params in the tensorboard file&quot;&quot;&quot;</span>
  307. <span class="n">hpm_string</span> <span class="o">=</span> <span class="s2">&quot;&quot;</span>
  308. <span class="k">for</span> <span class="n">hpm</span> <span class="ow">in</span> <span class="n">hpmstructs</span><span class="p">:</span>
  309. <span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">val</span> <span class="ow">in</span> <span class="n">hpm</span><span class="o">.</span><span class="vm">__dict__</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
  310. <span class="n">hpm_string</span> <span class="o">+=</span> <span class="s1">&#39;</span><span class="si">{}</span><span class="s1">: </span><span class="si">{}</span><span class="s1"> </span><span class="se">\n</span><span class="s1"> &#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">val</span><span class="p">)</span>
  311. <span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">val</span> <span class="ow">in</span> <span class="n">special_conf</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
  312. <span class="n">hpm_string</span> <span class="o">+=</span> <span class="s1">&#39;</span><span class="si">{}</span><span class="s1">: </span><span class="si">{}</span><span class="s1"> </span><span class="se">\n</span><span class="s1"> &#39;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">val</span><span class="p">)</span>
  313. <span class="n">writer</span><span class="o">.</span><span class="n">add_text</span><span class="p">(</span><span class="s2">&quot;Hyper_parameters&quot;</span><span class="p">,</span> <span class="n">hpm_string</span><span class="p">)</span>
  314. <span class="n">writer</span><span class="o">.</span><span class="n">flush</span><span class="p">()</span></div>
  315. <span class="c1"># TODO: This should probably move into datasets/datasets_utils.py?</span>
  316. <div class="viewcode-block" id="unpack_batch_items"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.sg_model_utils.unpack_batch_items">[docs]</a><span class="k">def</span> <span class="nf">unpack_batch_items</span><span class="p">(</span><span class="n">batch_items</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">tuple</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]):</span>
  317. <span class="sd">&quot;&quot;&quot;</span>
  318. <span class="sd"> Adds support for unpacking batch items in train/validation loop.</span>
  319. <span class="sd"> @param batch_items: (Union[tuple, torch.Tensor]) returned by the data loader, which is expected to be in one of</span>
  320. <span class="sd"> the following formats:</span>
  321. <span class="sd"> 1. torch.Tensor or tuple, s.t inputs = batch_items[0], targets = batch_items[1] and len(batch_items) = 2</span>
  322. <span class="sd"> 2. tuple: (inputs, targets, additional_batch_items)</span>
  323. <span class="sd"> where inputs are fed to the network, targets are their corresponding labels and additional_batch_items is a</span>
  324. <span class="sd"> dictionary (format {additional_batch_item_i_name: additional_batch_item_i ...}) which can be accessed through</span>
  325. <span class="sd"> the phase context under the attribute additional_batch_item_i_name, using a phase callback.</span>
  326. <span class="sd"> @return: inputs, target, additional_batch_items</span>
  327. <span class="sd"> &quot;&quot;&quot;</span>
  328. <span class="n">additional_batch_items</span> <span class="o">=</span> <span class="p">{}</span>
  329. <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">batch_items</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span>
  330. <span class="n">inputs</span><span class="p">,</span> <span class="n">target</span> <span class="o">=</span> <span class="n">batch_items</span>
  331. <span class="k">elif</span> <span class="nb">len</span><span class="p">(</span><span class="n">batch_items</span><span class="p">)</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span>
  332. <span class="n">inputs</span><span class="p">,</span> <span class="n">target</span><span class="p">,</span> <span class="n">additional_batch_items</span> <span class="o">=</span> <span class="n">batch_items</span>
  333. <span class="k">else</span><span class="p">:</span>
  334. <span class="k">raise</span> <span class="n">UnsupportedBatchItemsFormat</span><span class="p">()</span>
  335. <span class="k">return</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">target</span><span class="p">,</span> <span class="n">additional_batch_items</span></div>
  336. <div class="viewcode-block" id="log_uncaught_exceptions"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.sg_model_utils.log_uncaught_exceptions">[docs]</a><span class="k">def</span> <span class="nf">log_uncaught_exceptions</span><span class="p">(</span><span class="n">logger</span><span class="p">):</span>
  337. <span class="sd">&quot;&quot;&quot;</span>
  338. <span class="sd"> Makes logger log uncaught exceptions</span>
  339. <span class="sd"> @param logger: logging.Logger</span>
  340. <span class="sd"> @return: None</span>
  341. <span class="sd"> &quot;&quot;&quot;</span>
  342. <span class="k">def</span> <span class="nf">handle_exception</span><span class="p">(</span><span class="n">exc_type</span><span class="p">,</span> <span class="n">exc_value</span><span class="p">,</span> <span class="n">exc_traceback</span><span class="p">):</span>
  343. <span class="k">if</span> <span class="nb">issubclass</span><span class="p">(</span><span class="n">exc_type</span><span class="p">,</span> <span class="ne">KeyboardInterrupt</span><span class="p">):</span>
  344. <span class="n">sys</span><span class="o">.</span><span class="n">__excepthook__</span><span class="p">(</span><span class="n">exc_type</span><span class="p">,</span> <span class="n">exc_value</span><span class="p">,</span> <span class="n">exc_traceback</span><span class="p">)</span>
  345. <span class="k">return</span>
  346. <span class="n">logger</span><span class="o">.</span><span class="n">error</span><span class="p">(</span><span class="s2">&quot;Uncaught exception&quot;</span><span class="p">,</span> <span class="n">exc_info</span><span class="o">=</span><span class="p">(</span><span class="n">exc_type</span><span class="p">,</span> <span class="n">exc_value</span><span class="p">,</span> <span class="n">exc_traceback</span><span class="p">))</span>
  347. <span class="n">sys</span><span class="o">.</span><span class="n">excepthook</span> <span class="o">=</span> <span class="n">handle_exception</span></div>
  348. </pre></div>
  349. </div>
  350. </div>
  351. <footer>
  352. <hr/>
  353. <div role="contentinfo">
  354. <p>&#169; Copyright 2021, SuperGradients team.</p>
  355. </div>
  356. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  357. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  358. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  359. </footer>
  360. </div>
  361. </div>
  362. </section>
  363. </div>
  364. <script>
  365. jQuery(function () {
  366. SphinxRtdTheme.Navigation.enable(true);
  367. });
  368. </script>
  369. </body>
  370. </html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.training.utils.ssd_utils &mdash; SuperGradients 1.0 documentation</title>
  7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
  10. <!--[if lt IE 9]>
  11. <script src="../../../../_static/js/html5shiv.min.js"></script>
  12. <![endif]-->
  13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
  14. <script src="../../../../_static/jquery.js"></script>
  15. <script src="../../../../_static/underscore.js"></script>
  16. <script src="../../../../_static/doctools.js"></script>
  17. <script src="../../../../_static/js/theme.js"></script>
  18. <link rel="index" title="Index" href="../../../../genindex.html" />
  19. <link rel="search" title="Search" href="../../../../search.html" />
  20. </head>
  21. <body class="wy-body-for-nav">
  22. <div class="wy-grid-for-nav">
  23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  24. <div class="wy-side-scroll">
  25. <div class="wy-side-nav-search" >
  26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
  27. </a>
  28. <div role="search">
  29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
  30. <input type="text" name="q" placeholder="Search docs" />
  31. <input type="hidden" name="check_keywords" value="yes" />
  32. <input type="hidden" name="area" value="default" />
  33. </form>
  34. </div>
  35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
  37. <ul>
  38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
  39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
  40. </ul>
  41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
  42. <ul>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
  45. </ul>
  46. <p class="caption"><span class="caption-text">User Guide</span></p>
  47. <ul>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
  57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
  58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
  59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
  60. </ul>
  61. </div>
  62. </div>
  63. </nav>
  64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  66. <a href="../../../../index.html">SuperGradients</a>
  67. </nav>
  68. <div class="wy-nav-content">
  69. <div class="rst-content">
  70. <div role="navigation" aria-label="Page navigation">
  71. <ul class="wy-breadcrumbs">
  72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
  74. <li>super_gradients.training.utils.ssd_utils</li>
  75. <li class="wy-breadcrumbs-aside">
  76. </li>
  77. </ul>
  78. <hr/>
  79. </div>
  80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  81. <div itemprop="articleBody">
  82. <h1>Source code for super_gradients.training.utils.ssd_utils</h1><div class="highlight"><pre>
  83. <span></span><span class="kn">import</span> <span class="nn">itertools</span>
  84. <span class="kn">from</span> <span class="nn">math</span> <span class="kn">import</span> <span class="n">sqrt</span>
  85. <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span>
  86. <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
  87. <span class="kn">import</span> <span class="nn">torch</span>
  88. <span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">non_max_suppression</span><span class="p">,</span> <span class="n">NMS_Type</span><span class="p">,</span> \
  89. <span class="n">matrix_non_max_suppression</span><span class="p">,</span> <span class="n">DetectionPostPredictionCallback</span>
  90. <div class="viewcode-block" id="DefaultBoxes"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.ssd_utils.DefaultBoxes">[docs]</a><span class="k">class</span> <span class="nc">DefaultBoxes</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span>
  91. <span class="sd">&quot;&quot;&quot;</span>
  92. <span class="sd"> Default Boxes, (aka: anchor boxes or priors boxes) used by SSD model</span>
  93. <span class="sd"> &quot;&quot;&quot;</span>
  94. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">fig_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">feat_size</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="n">scales</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="n">aspect_ratios</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]],</span>
  95. <span class="n">scale_xy</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">scale_wh</span><span class="o">=</span><span class="mf">0.2</span><span class="p">):</span>
  96. <span class="sd">&quot;&quot;&quot;</span>
  97. <span class="sd"> For each feature map i (each predicting level, grids) the anchors (a.k.a. default boxes) will be:</span>
  98. <span class="sd"> [</span>
  99. <span class="sd"> [s, s], [sqrt(s * s_next), sqrt(s * s_next)],</span>
  100. <span class="sd"> [s * sqrt(alpha1), s / sqrt(alpha1)], [s / sqrt(alpha1), s * sqrt(alpha1)],</span>
  101. <span class="sd"> ...</span>
  102. <span class="sd"> [s * sqrt(alphaN), s / sqrt(alphaN)], [s / sqrt(alphaN), s * sqrt(alphaN)]</span>
  103. <span class="sd"> ] / fig_size</span>
  104. <span class="sd"> where:</span>
  105. <span class="sd"> * s = scale[i] - this level&#39;s scale</span>
  106. <span class="sd"> * s_next = scale[i + 1] - next level&#39;s scale</span>
  107. <span class="sd"> * alpha1, ... alphaN - this level&#39;s alphas, e.g. [2, 3]</span>
  108. <span class="sd"> * fig_size - input image resolution</span>
  109. <span class="sd"> Because of division by image resolution, the anchors will be in image coordinates normalized to [0, 1]</span>
  110. <span class="sd"> :param fig_size: input image resolution</span>
  111. <span class="sd"> :param feat_size: resolution of all feature maps with predictions (grids)</span>
  112. <span class="sd"> :param scales: anchor sizes in pixels for each feature level;</span>
  113. <span class="sd"> one value per level will be used to generate anchors based on the formula above</span>
  114. <span class="sd"> :param aspect_ratios: lists of alpha values for each feature map</span>
  115. <span class="sd"> :param scale_xy: predicted boxes will be with a factor scale_xy</span>
  116. <span class="sd"> so will be multiplied by scale_xy during post-prediction processing;</span>
  117. <span class="sd"> e.g. scale 0.1 means that prediction will be 10 times bigger</span>
  118. <span class="sd"> (improves predictions quality)</span>
  119. <span class="sd"> :param scale_wh: same logic as in scale_xy, but for width and height.</span>
  120. <span class="sd"> &quot;&quot;&quot;</span>
  121. <span class="bp">self</span><span class="o">.</span><span class="n">feat_size</span> <span class="o">=</span> <span class="n">feat_size</span>
  122. <span class="bp">self</span><span class="o">.</span><span class="n">fig_size</span> <span class="o">=</span> <span class="n">fig_size</span>
  123. <span class="bp">self</span><span class="o">.</span><span class="n">scale_xy_</span> <span class="o">=</span> <span class="n">scale_xy</span>
  124. <span class="bp">self</span><span class="o">.</span><span class="n">scale_wh_</span> <span class="o">=</span> <span class="n">scale_wh</span>
  125. <span class="c1"># According to https://github.com/weiliu89/caffe</span>
  126. <span class="c1"># Calculation method slightly different from paper</span>
  127. <span class="bp">self</span><span class="o">.</span><span class="n">scales</span> <span class="o">=</span> <span class="n">scales</span>
  128. <span class="bp">self</span><span class="o">.</span><span class="n">aspect_ratios</span> <span class="o">=</span> <span class="n">aspect_ratios</span>
  129. <span class="bp">self</span><span class="o">.</span><span class="n">default_boxes</span> <span class="o">=</span> <span class="p">[]</span>
  130. <span class="bp">self</span><span class="o">.</span><span class="n">num_anchors</span> <span class="o">=</span> <span class="p">[]</span>
  131. <span class="c1"># size of feature and number of feature</span>
  132. <span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="n">sfeat</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">feat_size</span><span class="p">):</span>
  133. <span class="n">sk1</span> <span class="o">=</span> <span class="n">scales</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span>
  134. <span class="n">sk2</span> <span class="o">=</span> <span class="n">scales</span><span class="p">[</span><span class="n">idx</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span>
  135. <span class="n">sk3</span> <span class="o">=</span> <span class="n">sqrt</span><span class="p">(</span><span class="n">sk1</span> <span class="o">*</span> <span class="n">sk2</span><span class="p">)</span>
  136. <span class="n">all_sizes</span> <span class="o">=</span> <span class="p">[(</span><span class="n">sk1</span><span class="p">,</span> <span class="n">sk1</span><span class="p">),</span> <span class="p">(</span><span class="n">sk3</span><span class="p">,</span> <span class="n">sk3</span><span class="p">)]</span>
  137. <span class="k">for</span> <span class="n">alpha</span> <span class="ow">in</span> <span class="n">aspect_ratios</span><span class="p">[</span><span class="n">idx</span><span class="p">]:</span>
  138. <span class="n">w</span><span class="p">,</span> <span class="n">h</span> <span class="o">=</span> <span class="n">sk1</span> <span class="o">*</span> <span class="n">sqrt</span><span class="p">(</span><span class="n">alpha</span><span class="p">),</span> <span class="n">sk1</span> <span class="o">/</span> <span class="n">sqrt</span><span class="p">(</span><span class="n">alpha</span><span class="p">)</span>
  139. <span class="n">all_sizes</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">w</span><span class="p">,</span> <span class="n">h</span><span class="p">))</span>
  140. <span class="n">all_sizes</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">h</span><span class="p">,</span> <span class="n">w</span><span class="p">))</span>
  141. <span class="n">all_sizes</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">all_sizes</span><span class="p">)</span> <span class="o">/</span> <span class="n">fig_size</span>
  142. <span class="bp">self</span><span class="o">.</span><span class="n">num_anchors</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">all_sizes</span><span class="p">))</span>
  143. <span class="k">for</span> <span class="n">w</span><span class="p">,</span> <span class="n">h</span> <span class="ow">in</span> <span class="n">all_sizes</span><span class="p">:</span>
  144. <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">j</span> <span class="ow">in</span> <span class="n">itertools</span><span class="o">.</span><span class="n">product</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">sfeat</span><span class="p">),</span> <span class="n">repeat</span><span class="o">=</span><span class="mi">2</span><span class="p">):</span>
  145. <span class="n">cx</span><span class="p">,</span> <span class="n">cy</span> <span class="o">=</span> <span class="p">(</span><span class="n">j</span> <span class="o">+</span> <span class="mf">0.5</span><span class="p">)</span> <span class="o">/</span> <span class="n">sfeat</span><span class="p">,</span> <span class="p">(</span><span class="n">i</span> <span class="o">+</span> <span class="mf">0.5</span><span class="p">)</span> <span class="o">/</span> <span class="n">sfeat</span>
  146. <span class="bp">self</span><span class="o">.</span><span class="n">default_boxes</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">cx</span><span class="p">,</span> <span class="n">cy</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="n">h</span><span class="p">))</span>
  147. <span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">default_boxes</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">float</span><span class="p">)</span>
  148. <span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span><span class="o">.</span><span class="n">clamp_</span><span class="p">(</span><span class="nb">min</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="nb">max</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
  149. <span class="c1"># For IoU calculation</span>
  150. <span class="bp">self</span><span class="o">.</span><span class="n">dboxes_xyxy</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
  151. <span class="bp">self</span><span class="o">.</span><span class="n">dboxes_xyxy</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">-</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span>
  152. <span class="bp">self</span><span class="o">.</span><span class="n">dboxes_xyxy</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span>
  153. <span class="bp">self</span><span class="o">.</span><span class="n">dboxes_xyxy</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">]</span>
  154. <span class="bp">self</span><span class="o">.</span><span class="n">dboxes_xyxy</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">]</span>
  155. <span class="nd">@property</span>
  156. <span class="k">def</span> <span class="nf">scale_xy</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  157. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale_xy_</span>
  158. <span class="nd">@property</span>
  159. <span class="k">def</span> <span class="nf">scale_wh</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  160. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale_wh_</span>
  161. <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">order</span><span class="o">=</span><span class="s2">&quot;xyxy&quot;</span><span class="p">):</span>
  162. <span class="k">if</span> <span class="n">order</span> <span class="o">==</span> <span class="s2">&quot;xyxy&quot;</span><span class="p">:</span>
  163. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">dboxes_xyxy</span>
  164. <span class="k">if</span> <span class="n">order</span> <span class="o">==</span> <span class="s2">&quot;xywh&quot;</span><span class="p">:</span>
  165. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span></div>
  166. <div class="viewcode-block" id="SSDPostPredictCallback"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.ssd_utils.SSDPostPredictCallback">[docs]</a><span class="k">class</span> <span class="nc">SSDPostPredictCallback</span><span class="p">(</span><span class="n">DetectionPostPredictionCallback</span><span class="p">):</span>
  167. <span class="sd">&quot;&quot;&quot;</span>
  168. <span class="sd"> post prediction callback module to convert and filter predictions coming from the SSD net to a format</span>
  169. <span class="sd"> used by all other detection models</span>
  170. <span class="sd"> &quot;&quot;&quot;</span>
  171. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">conf</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.001</span><span class="p">,</span> <span class="n">iou</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.6</span><span class="p">,</span> <span class="n">classes</span><span class="p">:</span> <span class="nb">list</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
  172. <span class="n">max_predictions</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">300</span><span class="p">,</span>
  173. <span class="n">nms_type</span><span class="p">:</span> <span class="n">NMS_Type</span> <span class="o">=</span> <span class="n">NMS_Type</span><span class="o">.</span><span class="n">ITERATIVE</span><span class="p">,</span>
  174. <span class="n">multi_label_per_box</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
  175. <span class="sd">&quot;&quot;&quot;</span>
  176. <span class="sd"> Predictions of SSD contain unnormalized probabilities for a background class,</span>
  177. <span class="sd"> together with confidences for all the dataset classes. Background will be utilized and discarded,</span>
  178. <span class="sd"> so this callback will return 0-based classes without background</span>
  179. <span class="sd"> :param conf: confidence threshold</span>
  180. <span class="sd"> :param iou: IoU threshold</span>
  181. <span class="sd"> :param classes: (optional list) filter by class</span>
  182. <span class="sd"> :param nms_type: the type of nms to use (iterative or matrix)</span>
  183. <span class="sd"> :param multi_label_per_box: whether to use re-use each box with all possible labels</span>
  184. <span class="sd"> (instead of the maximum confidence all confidences above threshold</span>
  185. <span class="sd"> will be sent to NMS)</span>
  186. <span class="sd"> &quot;&quot;&quot;</span>
  187. <span class="nb">super</span><span class="p">(</span><span class="n">SSDPostPredictCallback</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
  188. <span class="bp">self</span><span class="o">.</span><span class="n">conf</span> <span class="o">=</span> <span class="n">conf</span>
  189. <span class="bp">self</span><span class="o">.</span><span class="n">iou</span> <span class="o">=</span> <span class="n">iou</span>
  190. <span class="bp">self</span><span class="o">.</span><span class="n">nms_type</span> <span class="o">=</span> <span class="n">nms_type</span>
  191. <span class="bp">self</span><span class="o">.</span><span class="n">classes</span> <span class="o">=</span> <span class="n">classes</span>
  192. <span class="bp">self</span><span class="o">.</span><span class="n">max_predictions</span> <span class="o">=</span> <span class="n">max_predictions</span>
  193. <span class="bp">self</span><span class="o">.</span><span class="n">multi_label_per_box</span> <span class="o">=</span> <span class="n">multi_label_per_box</span>
  194. <div class="viewcode-block" id="SSDPostPredictCallback.forward"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.ssd_utils.SSDPostPredictCallback.forward">[docs]</a> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">predictions</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  195. <span class="n">nms_input</span> <span class="o">=</span> <span class="n">predictions</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
  196. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">nms_type</span> <span class="o">==</span> <span class="n">NMS_Type</span><span class="o">.</span><span class="n">ITERATIVE</span><span class="p">:</span>
  197. <span class="n">nms_res</span> <span class="o">=</span> <span class="n">non_max_suppression</span><span class="p">(</span><span class="n">nms_input</span><span class="p">,</span> <span class="n">conf_thres</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">conf</span><span class="p">,</span> <span class="n">iou_thres</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">iou</span><span class="p">,</span>
  198. <span class="n">multi_label_per_box</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">multi_label_per_box</span><span class="p">,</span> <span class="n">with_confidence</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  199. <span class="k">else</span><span class="p">:</span>
  200. <span class="n">nms_res</span> <span class="o">=</span> <span class="n">matrix_non_max_suppression</span><span class="p">(</span><span class="n">nms_input</span><span class="p">,</span> <span class="n">conf_thres</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">conf</span><span class="p">,</span>
  201. <span class="n">max_num_of_detections</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">max_predictions</span><span class="p">)</span>
  202. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_filter_max_predictions</span><span class="p">(</span><span class="n">nms_res</span><span class="p">)</span></div>
  203. <span class="k">def</span> <span class="nf">_filter_max_predictions</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">res</span><span class="p">:</span> <span class="n">List</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">:</span>
  204. <span class="n">res</span><span class="p">[:]</span> <span class="o">=</span> <span class="p">[</span><span class="n">im</span><span class="p">[:</span><span class="bp">self</span><span class="o">.</span><span class="n">max_predictions</span><span class="p">]</span> <span class="k">if</span> <span class="p">(</span><span class="n">im</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">im</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">&gt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">max_predictions</span><span class="p">)</span> <span class="k">else</span> <span class="n">im</span> <span class="k">for</span> <span class="n">im</span> <span class="ow">in</span> <span class="n">res</span><span class="p">]</span>
  205. <span class="k">return</span> <span class="n">res</span></div>
  206. </pre></div>
  207. </div>
  208. </div>
  209. <footer>
  210. <hr/>
  211. <div role="contentinfo">
  212. <p>&#169; Copyright 2021, SuperGradients team.</p>
  213. </div>
  214. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  215. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  216. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  217. </footer>
  218. </div>
  219. </div>
  220. </section>
  221. </div>
  222. <script>
  223. jQuery(function () {
  224. SphinxRtdTheme.Navigation.enable(true);
  225. });
  226. </script>
  227. </body>
  228. </html>
Discard
Some lines were truncated since they exceed the maximum allowed length of 500, please use a local Git client to see the full diff.
@@ -3,10 +3,11 @@
 <head>
 <head>
   <meta charset="utf-8" />
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-  <title>super_gradients.training.utils.utils &mdash; SuperGradients 1.0 documentation</title>
+  <title>super_gradients.training.utils.utils &mdash; SuperGradients 3.0.3 documentation</title>
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
       <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
+      <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
   <!--[if lt IE 9]>
   <!--[if lt IE 9]>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
     <script src="../../../../_static/js/html5shiv.min.js"></script>
   <![endif]-->
   <![endif]-->
@@ -14,7 +15,9 @@
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/jquery.js"></script>
         <script src="../../../../_static/underscore.js"></script>
         <script src="../../../../_static/underscore.js"></script>
+        <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
         <script src="../../../../_static/doctools.js"></script>
         <script src="../../../../_static/doctools.js"></script>
+        <script src="../../../../_static/sphinx_highlight.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <script src="../../../../_static/js/theme.js"></script>
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="index" title="Index" href="../../../../genindex.html" />
     <link rel="search" title="Search" href="../../../../search.html" /> 
     <link rel="search" title="Search" href="../../../../search.html" /> 
@@ -35,30 +38,29 @@
   </form>
   </form>
 </div>
 </div>
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
         </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
-              <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
+              <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
 <ul>
 <ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
+<li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
 </ul>
 </ul>
-<p class="caption"><span class="caption-text">Technical Documentation</span></p>
+<p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
 <ul>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
-</ul>
-<p class="caption"><span class="caption-text">User Guide</span></p>
-<ul>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
-<li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
 </ul>
 </ul>
 
 
         </div>
         </div>
@@ -90,7 +92,7 @@
 <span class="kn">import</span> <span class="nn">time</span>
 <span class="kn">import</span> <span class="nn">time</span>
 <span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">lru_cache</span>
 <span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">lru_cache</span>
 <span class="kn">from</span> <span class="nn">pathlib</span> <span class="kn">import</span> <span class="n">Path</span>
 <span class="kn">from</span> <span class="nn">pathlib</span> <span class="kn">import</span> <span class="n">Path</span>
-<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Mapping</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">List</span>
+<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Mapping</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Dict</span>
 <span class="kn">from</span> <span class="nn">zipfile</span> <span class="kn">import</span> <span class="n">ZipFile</span>
 <span class="kn">from</span> <span class="nn">zipfile</span> <span class="kn">import</span> <span class="n">ZipFile</span>
 <span class="kn">import</span> <span class="nn">os</span>
 <span class="kn">import</span> <span class="nn">os</span>
 <span class="kn">from</span> <span class="nn">jsonschema</span> <span class="kn">import</span> <span class="n">validate</span>
 <span class="kn">from</span> <span class="nn">jsonschema</span> <span class="kn">import</span> <span class="n">validate</span>
@@ -112,28 +114,40 @@
 <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
 <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
 
 
 
 
-<div class="viewcode-block" id="convert_to_tensor"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.convert_to_tensor">[docs]</a><span class="k">def</span> <span class="nf">convert_to_tensor</span><span class="p">(</span><span class="n">array</span><span class="p">):</span>
+<span class="k">def</span> <span class="nf">empty_list</span><span class="p">():</span>
+    <span class="sd">&quot;&quot;&quot;Instantiate an empty list. This is a workaround to generate a list with a function call in hydra, instead of the &quot;[]&quot;.&quot;&quot;&quot;</span>
+    <span class="k">return</span> <span class="nb">list</span><span class="p">()</span>
+
+
+<div class="viewcode-block" id="convert_to_tensor"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.utils.convert_to_tensor">[docs]</a><span class="k">def</span> <span class="nf">convert_to_tensor</span><span class="p">(</span><span class="n">array</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;Converts numpy arrays and lists to Torch tensors before calculation losses</span>
     <span class="sd">&quot;&quot;&quot;Converts numpy arrays and lists to Torch tensors before calculation losses</span>
 <span class="sd">    :param array: torch.tensor / Numpy array / List</span>
 <span class="sd">    :param array: torch.tensor / Numpy array / List</span>
 <span class="sd">    &quot;&quot;&quot;</span>
 <span class="sd">    &quot;&quot;&quot;</span>
     <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">FloatTensor</span><span class="p">(</span><span class="n">array</span><span class="p">)</span> <span class="k">if</span> <span class="nb">type</span><span class="p">(</span><span class="n">array</span><span class="p">)</span> <span class="o">!=</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="k">else</span> <span class="n">array</span></div>
     <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">FloatTensor</span><span class="p">(</span><span class="n">array</span><span class="p">)</span> <span class="k">if</span> <span class="nb">type</span><span class="p">(</span><span class="n">array</span><span class="p">)</span> <span class="o">!=</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="k">else</span> <span class="n">array</span></div>
 
 
 
 
-<div class="viewcode-block" id="HpmStruct"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.HpmStruct">[docs]</a><span class="k">class</span> <span class="nc">HpmStruct</span><span class="p">:</span>
+<div class="viewcode-block" id="HpmStruct"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.utils.HpmStruct">[docs]</a><span class="k">class</span> <span class="nc">HpmStruct</span><span class="p">:</span>
     <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">entries</span><span class="p">):</span>
     <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">entries</span><span class="p">):</span>
         <span class="bp">self</span><span class="o">.</span><span class="vm">__dict__</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">entries</span><span class="p">)</span>
         <span class="bp">self</span><span class="o">.</span><span class="vm">__dict__</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">entries</span><span class="p">)</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">schema</span> <span class="o">=</span> <span class="kc">None</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">schema</span> <span class="o">=</span> <span class="kc">None</span>
 
 
-<div class="viewcode-block" id="HpmStruct.set_schema"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.HpmStruct.set_schema">[docs]</a>    <span class="k">def</span> <span class="nf">set_schema</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">schema</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
+<div class="viewcode-block" id="HpmStruct.set_schema"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.utils.HpmStruct.set_schema">[docs]</a>    <span class="k">def</span> <span class="nf">set_schema</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">schema</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">schema</span> <span class="o">=</span> <span class="n">schema</span></div>
         <span class="bp">self</span><span class="o">.</span><span class="n">schema</span> <span class="o">=</span> <span class="n">schema</span></div>
 
 
-<div class="viewcode-block" id="HpmStruct.override"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.HpmStruct.override">[docs]</a>    <span class="k">def</span> <span class="nf">override</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">entries</span><span class="p">):</span>
+<div class="viewcode-block" id="HpmStruct.override"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.utils.HpmStruct.override">[docs]</a>    <span class="k">def</span> <span class="nf">override</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">entries</span><span class="p">):</span>
         <span class="n">recursive_override</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">,</span> <span class="n">entries</span><span class="p">)</span></div>
         <span class="n">recursive_override</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">,</span> <span class="n">entries</span><span class="p">)</span></div>
 
 
-<div class="viewcode-block" id="HpmStruct.to_dict"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.HpmStruct.to_dict">[docs]</a>    <span class="k">def</span> <span class="nf">to_dict</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
-        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="vm">__dict__</span></div>
+<div class="viewcode-block" id="HpmStruct.to_dict"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.utils.HpmStruct.to_dict">[docs]</a>    <span class="k">def</span> <span class="nf">to_dict</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">include_schema</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">dict</span><span
+        <span class="sd">&quot;&quot;&quot;Convert this HpmStruct instance into a dict.</span>
+<span class="sd">        :param include_schema: If True, also return the field &quot;schema&quot;</span>
+<span class="sd">        :return: Dict representation of this HpmStruct instance.</span>
+<span class="sd">        &quot;&quot;&quot;</span>
+        <span class="n">out_dict</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="vm">__dict__</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
+        <span class="k">if</span> <span class="ow">not</span> <span class="n">include_schema</span><span class="p">:</span>
+            <span class="n">out_dict</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;schema&quot;</span><span class="p">)</span>
+        <span class="k">return</span> <span class="n">out_dict</span></div>
 
 
-<div class="viewcode-block" id="HpmStruct.validate"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.HpmStruct.validate">[docs]</a>    <span class="k">def</span> <span class="nf">validate</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
+<div class="viewcode-block" id="HpmStruct.validate"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.utils.HpmStruct.validate">[docs]</a>    <span class="k">def</span> <span class="nf">validate</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
         <span class="sd">&quot;&quot;&quot;</span>
         <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">        Validate the current dict values according to the provided schema</span>
 <span class="sd">        Validate the current dict values according to the provided schema</span>
 <span class="sd">        :raises</span>
 <span class="sd">        :raises</span>
@@ -147,16 +161,16 @@
             <span class="n">validate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">schema</span><span class="p">)</span></div></div>
             <span class="n">validate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">schema</span><span class="p">)</span></div></div>
 
 
 
 
-<div class="viewcode-block" id="WrappedModel"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.WrappedModel">[docs]</a><span class="k">class</span> <span class="nc">WrappedModel</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
+<div class="viewcode-block" id="WrappedModel"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.utils.WrappedModel">[docs]</a><span class="k">class</span> <span class="nc">WrappedModel</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
     <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">module</span><span class="p">):</span>
     <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">module</span><span class="p">):</span>
         <span class="nb">super</span><span class="p">(</span><span class="n">WrappedModel</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
         <span class="nb">super</span><span class="p">(</span><span class="n">WrappedModel</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">module</span> <span class="o">=</span> <span class="n">module</span>  <span class="c1"># that I actually define.</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">module</span> <span class="o">=</span> <span class="n">module</span>  <span class="c1"># that I actually define.</span>
 
 
-<div class="viewcode-block" id="WrappedModel.forward"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.WrappedModel.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
+<div class="viewcode-block" id="WrappedModel.forward"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.utils.WrappedModel.forward">[docs]</a>    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
         <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">module</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></div></div>
         <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">module</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></div></div>
 
 
 
 
-<div class="viewcode-block" id="Timer"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.Timer">[docs]</a><span class="k">class</span> <span class="nc">Timer</span><span class="p">:</span>
+<div class="viewcode-block" id="Timer"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.utils.Timer">[docs]</a><span class="k">class</span> <span class="nc">Timer</span><span class="p">:</span>
     <span class="sd">&quot;&quot;&quot;A class to measure time handling both GPU &amp; CPU processes</span>
     <span class="sd">&quot;&quot;&quot;A class to measure time handling both GPU &amp; CPU processes</span>
 <span class="sd">    Returns time in milliseconds&quot;&quot;&quot;</span>
 <span class="sd">    Returns time in milliseconds&quot;&quot;&quot;</span>
 
 
@@ -174,13 +188,13 @@
         <span class="k">else</span><span class="p">:</span>
         <span class="k">else</span><span class="p">:</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">starter</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">ender</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">starter</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">ender</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span>
 
 
-<div class="viewcode-block" id="Timer.start"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.Timer.start">[docs]</a>    <span class="k">def</span> <span class="nf">start</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
+<div class="viewcode-block" id="Timer.start"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.utils.Timer.start">[docs]</a>    <span class="k">def</span> <span class="nf">start</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
         <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">on_gpu</span><span class="p">:</span>
         <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">on_gpu</span><span class="p">:</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">starter</span><span class="o">.</span><span class="n">record</span><span class="p">()</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">starter</span><span class="o">.</span><span class="n">record</span><span class="p">()</span>
         <span class="k">else</span><span class="p">:</span>
         <span class="k">else</span><span class="p">:</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">starter</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span></div>
             <span class="bp">self</span><span class="o">.</span><span class="n">starter</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span></div>
 
 
-<div class="viewcode-block" id="Timer.stop"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.Timer.stop">[docs]</a>    <span class="k">def</span> <span class="nf">stop</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
+<div class="viewcode-block" id="Timer.stop"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.utils.Timer.stop">[docs]</a>    <span class="k">def</span> <span class="nf">stop</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
         <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">on_gpu</span><span class="p">:</span>
         <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">on_gpu</span><span class="p">:</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">ender</span><span class="o">.</span><span class="n">record</span><span class="p">()</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">ender</span><span class="o">.</span><span class="n">record</span><span class="p">()</span>
             <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
             <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
@@ -193,7 +207,7 @@
         <span class="k">return</span> <span class="n">timer</span></div></div>
         <span class="k">return</span> <span class="n">timer</span></div></div>
 
 
 
 
-<div class="viewcode-block" id="AverageMeter"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.AverageMeter">[docs]</a><span class="k">class</span> <span class="nc">AverageMeter</span><span class="p">:</span>
+<span class="k">class</span> <span class="nc">AverageMeter</span><span class="p">:</span>
     <span class="sd">&quot;&quot;&quot;A class to calculate the average of a metric, for each batch</span>
     <span class="sd">&quot;&quot;&quot;A class to calculate the average of a metric, for each batch</span>
 <span class="sd">    during training/testing&quot;&quot;&quot;</span>
 <span class="sd">    during training/testing&quot;&quot;&quot;</span>
 
 
@@ -201,7 +215,7 @@
         <span class="bp">self</span><span class="o">.</span><span class="n">_sum</span> <span class="o">=</span> <span class="kc">None</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">_sum</span> <span class="o">=</span> <span class="kc">None</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">_count</span> <span class="o">=</span> <span class="mi">0</span>
         <span class="bp">self</span><span class="o">.</span><span class="n">_count</span> <span class="o">=</span> <span class="mi">0</span>
 
 
-<div class="viewcode-block" id="AverageMeter.update"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.AverageMeter.update">[docs]</a>    <span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span cl
+    <span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">,</span> <span class="nb">list</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p
 
 
         <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">value</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
         <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">value</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
             <span class="n">value</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">value</span><span class="p">)</span>
             <span class="n">value</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">value</span><span class="p">)</span>
@@ -211,20 +225,20 @@
         <span class="k">else</span><span class="p">:</span>
         <span class="k">else</span><span class="p">:</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">_sum</span> <span class="o">+=</span> <span class="n">value</span> <span class="o">*</span> <span class="n">batch_size</span>
             <span class="bp">self</span><span class="o">.</span><span class="n">_sum</span> <span class="o">+=</span> <span class="n">value</span> <span class="o">*</span> <span class="n">batch_size</span>
 
 
-        <span class="bp">self</span><span class="o">.</span><span class="n">_count</span> <span class="o">+=</span> <span class="n">batch_size</span></div>
+        <span class="bp">self</span><span class="o">.</span><span class="n">_count</span> <span class="o">+=</span> <span class="n">batch_size</span>
 
 
     <span class="nd">@property</span>
     <span class="nd">@property</span>
     <span class="k">def</span> <span class="nf">average</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
     <span class="k">def</span> <span class="nf">average</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
         <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_sum</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
         <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_sum</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
             <span class="k">return</span> <span class="mi">0</span>
             <span class="k">return</span> <span class="mi">0</span>
         <span class="k">return</span> <span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">_sum</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">_count</span><span class="p">)</span><span class="o">.</span><span class="fm">__float__</span><span class="p">())</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_sum</span><span class="o">.</span><span class="
         <span class="k">return</span> <span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">_sum</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">_count</span><span class="p">)</span><span class="o">.</span><span class="fm">__float__</span><span class="p">())</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_sum</span><span class="o">.</span><span class="
-            <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_sum</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">_count</span><span class="p">)</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span></div>
+            <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_sum</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">_count</span><span class="p">)</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
 
 
         <span class="c1"># return (self._sum / self._count).__float__() if self._sum.dim() &lt; 1 or len(self._sum) == 1 \</span>
         <span class="c1"># return (self._sum / self._count).__float__() if self._sum.dim() &lt; 1 or len(self._sum) == 1 \</span>
         <span class="c1">#     else tuple((self._sum / self._count).cpu().numpy())</span>
         <span class="c1">#     else tuple((self._sum / self._count).cpu().numpy())</span>
 
 
 
 
-<div class="viewcode-block" id="tensor_container_to_device"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.tensor_container_to_device">[docs]</a><span class="k">def</span> <span class="nf">tensor_container_to_device</span><span class="p">(</span><span class="n">obj</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span 
+<div class="viewcode-block" id="tensor_container_to_device"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.utils.tensor_container_to_device">[docs]</a><span class="k">def</span> <span class="nf">tensor_container_to_device</span><span class="p">(</span><span class="n">obj</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class=
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    recursively send compounded objects to device (sending all tensors to device and maintaining structure)</span>
 <span class="sd">    recursively send compounded objects to device (sending all tensors to device and maintaining structure)</span>
 <span class="sd">        :param obj           the object to send to device (list / tuple / tensor / dict)</span>
 <span class="sd">        :param obj           the object to send to device (list / tuple / tensor / dict)</span>
@@ -245,7 +259,7 @@
         <span class="k">return</span> <span class="n">obj</span></div>
         <span class="k">return</span> <span class="n">obj</span></div>
 
 
 
 
-<div class="viewcode-block" id="get_param"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.get_param">[docs]</a><span class="k">def</span> <span class="nf">get_param</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
+<div class="viewcode-block" id="get_param"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.utils.get_param">[docs]</a><span class="k">def</span> <span class="nf">get_param</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    Retrieves a param from a parameter object/dict. If the parameter does not exist, will return default_val.</span>
 <span class="sd">    Retrieves a param from a parameter object/dict. If the parameter does not exist, will return default_val.</span>
 <span class="sd">    In case the default_val is of type dictionary, and a value is found in the params - the function</span>
 <span class="sd">    In case the default_val is of type dictionary, and a value is found in the params - the function</span>
@@ -279,23 +293,23 @@
         <span class="k">return</span> <span class="n">default_val</span></div>
         <span class="k">return</span> <span class="n">default_val</span></div>
 
 
 
 
-<div class="viewcode-block" id="static_vars"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.static_vars">[docs]</a><span class="k">def</span> <span class="nf">static_vars</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
+<span class="k">def</span> <span class="nf">static_vars</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
     <span class="k">def</span> <span class="nf">decorate</span><span class="p">(</span><span class="n">func</span><span class="p">):</span>
     <span class="k">def</span> <span class="nf">decorate</span><span class="p">(</span><span class="n">func</span><span class="p">):</span>
         <span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">kwargs</span><span class="p">:</span>
         <span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">kwargs</span><span class="p">:</span>
             <span class="nb">setattr</span><span class="p">(</span><span class="n">func</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">kwargs</span><span class="p">[</span><span class="n">k</span><span class="p">])</span>
             <span class="nb">setattr</span><span class="p">(</span><span class="n">func</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">kwargs</span><span class="p">[</span><span class="n">k</span><span class="p">])</span>
         <span class="k">return</span> <span class="n">func</span>
         <span class="k">return</span> <span class="n">func</span>
 
 
-    <span class="k">return</span> <span class="n">decorate</span></div>
+    <span class="k">return</span> <span class="n">decorate</span>
 
 
 
 
-<div class="viewcode-block" id="print_once"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.print_once">[docs]</a><span class="nd">@static_vars</span><span class="p">(</span><span class="n">printed</span><span class="o">=</span><span class="nb">set</span><span class="p">())</span>
+<span class="nd">@static_vars</span><span class="p">(</span><span class="n">printed</span><span class="o">=</span><span class="nb">set</span><span class="p">())</span>
 <span class="k">def</span> <span class="nf">print_once</span><span class="p">(</span><span class="n">s</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
 <span class="k">def</span> <span class="nf">print_once</span><span class="p">(</span><span class="n">s</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
     <span class="k">if</span> <span class="n">s</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">print_once</span><span class="o">.</span><span class="n">printed</span><span class="p">:</span>
     <span class="k">if</span> <span class="n">s</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">print_once</span><span class="o">.</span><span class="n">printed</span><span class="p">:</span>
         <span class="n">print_once</span><span class="o">.</span><span class="n">printed</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">s</span><span class="p">)</span>
         <span class="n">print_once</span><span class="o">.</span><span class="n">printed</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">s</span><span class="p">)</span>
-        <span class="nb">print</span><span class="p">(</span><span class="n">s</span><span class="p">)</span></div>
+        <span class="nb">print</span><span class="p">(</span><span class="n">s</span><span class="p">)</span>
 
 
 
 
-<div class="viewcode-block" id="move_state_dict_to_device"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.move_state_dict_to_device">[docs]</a><span class="k">def</span> <span class="nf">move_state_dict_to_device</span><span class="p">(</span><span class="n">model_sd</span><span class="p">,</span> <span class="n">device</span><span class="p">):</span>
+<span class="k">def</span> <span class="nf">move_state_dict_to_device</span><span class="p">(</span><span class="n">model_sd</span><span class="p">,</span> <span class="n">device</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    Moving model state dict tensors to target device (cuda or cpu)</span>
 <span class="sd">    Moving model state dict tensors to target device (cuda or cpu)</span>
 <span class="sd">    :param model_sd: model state dict</span>
 <span class="sd">    :param model_sd: model state dict</span>
@@ -303,10 +317,10 @@
 <span class="sd">    &quot;&quot;&quot;</span>
 <span class="sd">    &quot;&quot;&quot;</span>
     <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">model_sd</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
     <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">model_sd</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
         <span class="n">model_sd</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">v</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
         <span class="n">model_sd</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">v</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
-    <span class="k">return</span> <span class="n">model_sd</span></div>
+    <span class="k">return</span> <span class="n">model_sd</span>
 
 
 
 
-<div class="viewcode-block" id="random_seed"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.random_seed">[docs]</a><span class="k">def</span> <span class="nf">random_seed</span><span class="p">(</span><span class="n">is_ddp</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">seed</span><span class="p">):</span>
+<div class="viewcode-block" id="random_seed"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.utils.random_seed">[docs]</a><span class="k">def</span> <span class="nf">random_seed</span><span class="p">(</span><span class="n">is_ddp</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">seed</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    Sets random seed of numpy, torch and random.</span>
 <span class="sd">    Sets random seed of numpy, torch and random.</span>
 
 
@@ -321,7 +335,7 @@
     <span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="n">seed</span> <span class="o">+</span> <span class="n">rank</span><span class="p">)</span></div>
     <span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="n">seed</span> <span class="o">+</span> <span class="n">rank</span><span class="p">)</span></div>
 
 
 
 
-<div class="viewcode-block" id="load_func"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.load_func">[docs]</a><span class="k">def</span> <span class="nf">load_func</span><span class="p">(</span><span class="n">dotpath</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
+<span class="k">def</span> <span class="nf">load_func</span><span class="p">(</span><span class="n">dotpath</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    load function in module.  function is right-most segment.</span>
 <span class="sd">    load function in module.  function is right-most segment.</span>
 
 
@@ -332,10 +346,10 @@
 <span class="sd">    &quot;&quot;&quot;</span>
 <span class="sd">    &quot;&quot;&quot;</span>
     <span class="n">module_</span><span class="p">,</span> <span class="n">func</span> <span class="o">=</span> <span class="n">dotpath</span><span class="o">.</span><span class="n">rsplit</span><span class="p">(</span><span class="s2">&quot;.&quot;</span><span class="p">,</span> <span class="n">maxsplit</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
     <span class="n">module_</span><span class="p">,</span> <span class="n">func</span> <span class="o">=</span> <span class="n">dotpath</span><span class="o">.</span><span class="n">rsplit</span><span class="p">(</span><span class="s2">&quot;.&quot;</span><span class="p">,</span> <span class="n">maxsplit</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
     <span class="n">m</span> <span class="o">=</span> <span class="n">import_module</span><span class="p">(</span><span class="n">module_</span><span class="p">)</span>
     <span class="n">m</span> <span class="o">=</span> <span class="n">import_module</span><span class="p">(</span><span class="n">module_</span><span class="p">)</span>
-    <span class="k">return</span> <span class="nb">getattr</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">func</span><span class="p">)</span></div>
+    <span class="k">return</span> <span class="nb">getattr</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">func</span><span class="p">)</span>
 
 
 
 
-<div class="viewcode-block" id="get_filename_suffix_by_framework"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.get_filename_suffix_by_framework">[docs]</a><span class="k">def</span> <span class="nf">get_filename_suffix_by_framework</span><span class="p">(</span><span class="n">framework</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
+<span class="k">def</span> <span class="nf">get_filename_suffix_by_framework</span><span class="p">(</span><span class="n">framework</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    Return the file extension of framework.</span>
 <span class="sd">    Return the file extension of framework.</span>
 
 
@@ -359,10 +373,10 @@
     <span class="k">if</span> <span class="n">framework</span><span class="o">.</span><span class="n">upper</span><span class="p">()</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">frameworks_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
     <span class="k">if</span> <span class="n">framework</span><span class="o">.</span><span class="n">upper</span><span class="p">()</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">frameworks_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
         <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Unsupported framework: </span><span class="si">{</span><span class="n">framework</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
         <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Unsupported framework: </span><span class="si">{</span><span class="n">framework</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
 
 
-    <span class="k">return</span> <span class="n">frameworks_dict</span><span class="p">[</span><span class="n">framework</span><span class="o">.</span><span class="n">upper</span><span class="p">()]</span></div>
+    <span class="k">return</span> <span class="n">frameworks_dict</span><span class="p">[</span><span class="n">framework</span><span class="o">.</span><span class="n">upper</span><span class="p">()]</span>
 
 
 
 
-<div class="viewcode-block" id="check_models_have_same_weights"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.check_models_have_same_weights">[docs]</a><span class="k">def</span> <span class="nf">check_models_have_same_weights</span><span class="p">(</span><span class="n">model_1</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module
+<span class="k">def</span> <span class="nf">check_models_have_same_weights</span><span class="p">(</span><span class="n">model_1</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">model_2</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    Checks whether two networks have the same weights</span>
 <span class="sd">    Checks whether two networks have the same weights</span>
 
 
@@ -382,10 +396,10 @@
     <span class="k">if</span> <span class="n">models_differ</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
     <span class="k">if</span> <span class="n">models_differ</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
         <span class="k">return</span> <span class="kc">True</span>
         <span class="k">return</span> <span class="kc">True</span>
     <span class="k">else</span><span class="p">:</span>
     <span class="k">else</span><span class="p">:</span>
-        <span class="k">return</span> <span class="kc">False</span></div>
+        <span class="k">return</span> <span class="kc">False</span>
 
 
 
 
-<div class="viewcode-block" id="recursive_override"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.recursive_override">[docs]</a><span class="k">def</span> <span class="nf">recursive_override</span><span class="p">(</span><span class="n">base</span><span class="p">:</span> <span class="nb">dict</span><span class="p">,</span> <span class="n">extension</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span
+<span class="k">def</span> <span class="nf">recursive_override</span><span class="p">(</span><span class="n">base</span><span class="p">:</span> <span class="nb">dict</span><span class="p">,</span> <span class="n">extension</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
     <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">extension</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
     <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">extension</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
         <span class="k">if</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">base</span><span class="p">:</span>
         <span class="k">if</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">base</span><span class="p">:</span>
             <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">Mapping</span><span class="p">):</span>
             <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">Mapping</span><span class="p">):</span>
@@ -393,10 +407,10 @@
             <span class="k">else</span><span class="p">:</span>
             <span class="k">else</span><span class="p">:</span>
                 <span class="n">base</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">extension</span><span class="p">[</span><span class="n">k</span><span class="p">]</span>
                 <span class="n">base</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">extension</span><span class="p">[</span><span class="n">k</span><span class="p">]</span>
         <span class="k">else</span><span class="p">:</span>
         <span class="k">else</span><span class="p">:</span>
-            <span class="n">base</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">extension</span><span class="p">[</span><span class="n">k</span><span class="p">]</span></div>
+            <span class="n">base</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">extension</span><span class="p">[</span><span class="n">k</span><span class="p">]</span>
 
 
 
 
-<div class="viewcode-block" id="download_and_unzip_from_url"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.download_and_unzip_from_url">[docs]</a><span class="k">def</span> <span class="nf">download_and_unzip_from_url</span><span class="p">(</span><span class="n">url</span><span class="p">,</span> <span class="nb">dir</span><span class="o">=</span><span class="s1">&#39;.&#39;</span><span class="p">,</span> <span class="n">unzip</sp
+<span class="k">def</span> <span class="nf">download_and_unzip_from_url</span><span class="p">(</span><span class="n">url</span><span class="p">,</span> <span class="nb">dir</span><span class="o">=</span><span class="s1">&#39;.&#39;</span><span class="p">,</span> <span class="n">unzip</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">delete</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    Downloads a zip file from url to dir, and unzips it.</span>
 <span class="sd">    Downloads a zip file from url to dir, and unzips it.</span>
 
 
@@ -431,10 +445,10 @@
     <span class="nb">dir</span> <span class="o">=</span> <span class="n">Path</span><span class="p">(</span><span class="nb">dir</span><span class="p">)</span>
     <span class="nb">dir</span> <span class="o">=</span> <span class="n">Path</span><span class="p">(</span><span class="nb">dir</span><span class="p">)</span>
     <span class="nb">dir</span><span class="o">.</span><span class="n">mkdir</span><span class="p">(</span><span class="n">parents</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>  <span class="c1"># make directory</span>
     <span class="nb">dir</span><span class="o">.</span><span class="n">mkdir</span><span class="p">(</span><span class="n">parents</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>  <span class="c1"># make directory</span>
     <span class="k">for</span> <span class="n">u</span> <span class="ow">in</span> <span class="p">[</span><span class="n">url</span><span class="p">]</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">url</span><span class="p">,</span> <span class="p">(</span><span class="nb">str</span><span class="p">,</span> <span class="n">Path</span><span class="p">))</span> <span class="k">else</span> <span class="n">url</span><span class="p">:</span
     <span class="k">for</span> <span class="n">u</span> <span class="ow">in</span> <span class="p">[</span><span class="n">url</span><span class="p">]</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">url</span><span class="p">,</span> <span class="p">(</span><span class="nb">str</span><span class="p">,</span> <span class="n">Path</span><span class="p">))</span> <span class="k">else</span> <span class="n">url</span><span class="p">:</span
-        <span class="n">download_one</span><span class="p">(</span><span class="n">u</span><span class="p">,</span> <span class="nb">dir</span><span class="p">)</span></div>
+        <span class="n">download_one</span><span class="p">(</span><span class="n">u</span><span class="p">,</span> <span class="nb">dir</span><span class="p">)</span>
 
 
 
 
-<div class="viewcode-block" id="download_and_untar_from_url"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.download_and_untar_from_url">[docs]</a><span class="k">def</span> <span class="nf">download_and_untar_from_url</span><span class="p">(</span><span class="n">urls</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> <span class="nb">dir</span><spa
+<span class="k">def</span> <span class="nf">download_and_untar_from_url</span><span class="p">(</span><span class="n">urls</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> <span class="nb">dir</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Path</span><span class="p">]</span> <span class="o">=</span> <span cl
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    Download a file from url and untar.</span>
 <span class="sd">    Download a file from url and untar.</span>
 
 
@@ -460,10 +474,10 @@
         <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Extracting to </span><span class="si">{</span><span class="nb">dir</span><span class="si">}</span><span class="s1">...&#39;</span><span class="p">)</span>
         <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Extracting to </span><span class="si">{</span><span class="nb">dir</span><span class="si">}</span><span class="s1">...&#39;</span><span class="p">)</span>
         <span class="k">with</span> <span class="n">tarfile</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">filepath</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="n">modes</span><span class="p">[</span><span class="n">filepath</span><span class="o">.</span><span class="n">suffix</span><span class="p">])</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
         <span class="k">with</span> <span class="n">tarfile</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">filepath</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="n">modes</span><span class="p">[</span><span class="n">filepath</span><span class="o">.</span><span class="n">suffix</span><span class="p">])</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
             <span class="n">f</span><span class="o">.</span><span class="n">extractall</span><span class="p">(</span><span class="nb">dir</span><span class="p">)</span>
             <span class="n">f</span><span class="o">.</span><span class="n">extractall</span><span class="p">(</span><span class="nb">dir</span><span class="p">)</span>
-        <span class="n">filepath</span><span class="o">.</span><span class="n">unlink</span><span class="p">()</span></div>
+        <span class="n">filepath</span><span class="o">.</span><span class="n">unlink</span><span class="p">()</span>
 
 
 
 
-<div class="viewcode-block" id="make_divisible"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.make_divisible">[docs]</a><span class="k">def</span> <span class="nf">make_divisible</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">divisor</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">ce
+<span class="k">def</span> <span class="nf">make_divisible</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">divisor</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">ceil</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    Returns x evenly divisible by divisor.</span>
 <span class="sd">    Returns x evenly divisible by divisor.</span>
 <span class="sd">    If ceil=True it will return the closest larger number to the original x, and ceil=False the closest smaller number.</span>
 <span class="sd">    If ceil=True it will return the closest larger number to the original x, and ceil=False the closest smaller number.</span>
@@ -471,10 +485,10 @@
     <span class="k">if</span> <span class="n">ceil</span><span class="p">:</span>
     <span class="k">if</span> <span class="n">ceil</span><span class="p">:</span>
         <span class="k">return</span> <span class="n">math</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">x</span> <span class="o">/</span> <span class="n">divisor</span><span class="p">)</span> <span class="o">*</span> <span class="n">divisor</span>
         <span class="k">return</span> <span class="n">math</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">x</span> <span class="o">/</span> <span class="n">divisor</span><span class="p">)</span> <span class="o">*</span> <span class="n">divisor</span>
     <span class="k">else</span><span class="p">:</span>
     <span class="k">else</span><span class="p">:</span>
-        <span class="k">return</span> <span class="n">math</span><span class="o">.</span><span class="n">floor</span><span class="p">(</span><span class="n">x</span> <span class="o">/</span> <span class="n">divisor</span><span class="p">)</span> <span class="o">*</span> <span class="n">divisor</span></div>
+        <span class="k">return</span> <span class="n">math</span><span class="o">.</span><span class="n">floor</span><span class="p">(</span><span class="n">x</span> <span class="o">/</span> <span class="n">divisor</span><span class="p">)</span> <span class="o">*</span> <span class="n">divisor</span>
 
 
 
 
-<div class="viewcode-block" id="check_img_size_divisibility"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.check_img_size_divisibility">[docs]</a><span class="k">def</span> <span class="nf">check_img_size_divisibility</span><span class="p">(</span><span class="n">img_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">stride</span><span class="p">:</span> <span class="nb">int</spa
+<span class="k">def</span> <span class="nf">check_img_size_divisibility</span><span class="p">(</span><span class="n">img_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">stride</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">bool</span><span class="p">,</spa
     <span class="sd">&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;</span>
 <span class="sd">    :param img_size: Int, the size of the image (H or W).</span>
 <span class="sd">    :param img_size: Int, the size of the image (H or W).</span>
 <span class="sd">    :param stride: Int, the number to check if img_size is divisible by.</span>
 <span class="sd">    :param stride: Int, the number to check if img_size is divisible by.</span>
@@ -486,22 +500,22 @@
     <span class="k">if</span> <span class="n">new_size</span> <span class="o">!=</span> <span class="n">img_size</span><span class="p">:</span>
     <span class="k">if</span> <span class="n">new_size</span> <span class="o">!=</span> <span class="n">img_size</span><span class="p">:</span>
         <span class="k">return</span> <span class="kc">False</span><span class="p">,</span> <span class="p">(</span><span class="n">new_size</span><span class="p">,</span> <span class="n">make_divisible</span><span class="p">(</span><span class="n">img_size</span><span class="p">,</span> <span class="nb">int</span><span class="p">(</span><span class="n">stride</span><span class="p">),</span> <span class="n">ceil</span><span class="o">=</span><span class="kc">False</span><span class="p">))</span
         <span class="k">return</span> <span class="kc">False</span><span class="p">,</span> <span class="p">(</span><span class="n">new_size</span><span class="p">,</span> <span class="n">make_divisible</span><span class="p">(</span><span class="n">img_size</span><span class="p">,</span> <span class="nb">int</span><span class="p">(</span><span class="n">stride</span><span class="p">),</span> <span class="n">ceil</span><span class="o">=</span><span class="kc">False</span><span class="p">))</span
     <span class="k">else</span><span class="p">:</span>
     <span class="k">else</span><span class="p">:</span>
-        <span class="k">return</span> <span class="kc">True</span><span class="p">,</span> <span class="kc">None</span></div>
+        <span class="k">return</span> <span class="kc">True</span><span class="p">,</span> <span class="kc">None</span>
 
 
 
 
-<div class="viewcode-block" id="get_orientation_key"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.get_orientation_key">[docs]</a><span class="nd">@lru_cache</span><span class="p">(</span><span class="kc">None</span><span class="p">)</span>
+<span class="nd">@lru_cache</span><span class="p">(</span><span class="kc">None</span><span class="p">)</span>
 <span class="k">def</span> <span class="nf">get_orientation_key</span><span class="p">()</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
 <span class="k">def</span> <span class="nf">get_orientation_key</span><span class="p">()</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
     <span class="sd">&quot;&quot;&quot;Get the orientation key according to PIL, which is useful to get the image size for instance</span>
     <span class="sd">&quot;&quot;&quot;Get the orientation key according to PIL, which is useful to get the image size for instance</span>
 <span class="sd">    :return: Orientation key according to PIL&quot;&quot;&quot;</span>
 <span class="sd">    :return: Orientation key according to PIL&quot;&quot;&quot;</span>
     <span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">ExifTags</span><span class="o">.</span><span class="n">TAGS</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
     <span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">ExifTags</span><span class="o">.</span><span class="n">TAGS</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
         <span class="k">if</span> <span class="n">value</span> <span class="o">==</span> <span class="s1">&#39;Orientation&#39;</span><span class="p">:</span>
         <span class="k">if</span> <span class="n">value</span> <span class="o">==</span> <span class="s1">&#39;Orientation&#39;</span><span class="p">:</span>
-            <span class="k">return</span> <span class="n">key</span></div>
+            <span class="k">return</span> <span class="n">key</span>
 
 
 
 
-<div class="viewcode-block" id="exif_size"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.exif_size">[docs]</a><span class="k">def</span> <span class="nf">exif_size</span><span class="p">(</span><span class="n">image</span><span class="p">:</span> <span class="n">Image</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <s
+<span class="k">def</span> <span class="nf">exif_size</span><span class="p">(</span><span class="n">image</span><span class="p">:</span> <span class="n">Image</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]:</span>
     <span class="sd">&quot;&quot;&quot;Get the size of image.</span>
     <span class="sd">&quot;&quot;&quot;Get the size of image.</span>
 <span class="sd">    :param image:   The image to get size from</span>
 <span class="sd">    :param image:   The image to get size from</span>
-<span class="sd">    :return:        (width, height)</span>
+<span class="sd">    :return:        (height, width)</span>
 <span class="sd">    &quot;&quot;&quot;</span>
 <span class="sd">    &quot;&quot;&quot;</span>
 
 
     <span class="n">orientation_key</span> <span class="o">=</span> <span class="n">get_orientation_key</span><span class="p">()</span>
     <span class="n">orientation_key</span> <span class="o">=</span> <span class="n">get_orientation_key</span><span class="p">()</span>
@@ -519,14 +533,27 @@
                 <span class="n">image_size</span> <span class="o">=</span> <span class="p">(</span><span class="n">image_size</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">image_size</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
                 <span class="n">image_size</span> <span class="o">=</span> <span class="p">(</span><span class="n">image_size</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">image_size</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
     <span class="k">except</span> <span class="ne">Exception</span> <span class="k">as</span> <span class="n">ex</span><span class="p">:</span>
     <span class="k">except</span> <span class="ne">Exception</span> <span class="k">as</span> <span class="n">ex</span><span class="p">:</span>
         <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Caught Exception trying to rotate: &#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">image</span><span class="p">)</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">ex</span><span class="p">))</span>
         <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Caught Exception trying to rotate: &#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">image</span><span class="p">)</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">ex</span><span class="p">))</span>
-    <span class="n">height</span><span class="p">,</span> <span class="n">width</span> <span class="o">=</span> <span class="n">image_size</span>
-    <span class="k">return</span> <span class="n">width</span><span class="p">,</span> <span class="n">height</span></div>
+    <span class="n">width</span><span class="p">,</span> <span class="n">height</span> <span class="o">=</span> <span class="n">image_size</span>
+    <span class="k">return</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span>
 
 
 
 
-<div class="viewcode-block" id="get_image_size_from_path"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.get_image_size_from_path">[docs]</a><span class="k">def</span> <span class="nf">get_image_size_from_path</span><span class="p">(</span><span class="n">img_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span c
+<span class="k">def</span> <span class="nf">get_image_size_from_path</span><span class="p">(</span><span class="n">img_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]:</span>
     <span class="sd">&quot;&quot;&quot;Get the image size of an image at a specific path&quot;&quot;&quot;</span>
     <span class="sd">&quot;&quot;&quot;Get the image size of an image at a specific path&quot;&quot;&quot;</span>
     <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">img_path</span><span class="p">,</span> <span class="s1">&#39;rb&#39;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
     <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">img_path</span><span class="p">,</span> <span class="s1">&#39;rb&#39;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
-        <span class="k">return</span> <span class="n">exif_size</span><span class="p">(</span><span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">f</span><span class="p">))</span></div>
+        <span class="k">return</span> <span class="n">exif_size</span><span class="p">(</span><span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">f</span><span class="p">))</span>
+
+
+<span class="k">def</span> <span class="nf">override_default_params_without_nones</span><span class="p">(</span><span class="n">params</span><span class="p">:</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">default_params</span><span class="p">:</span> <span class="n">Dict</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Dict</span><span class="p">:</span>
+    <span class="sd">&quot;&quot;&quot;</span>
+<span class="sd">    Helper method for overriding default dictionary&#39;s entries excluding entries with None values.</span>
+<span class="sd">    :param params: dict, output dictionary which will take the defaults.</span>
+<span class="sd">    :param default_params: dict, dictionary for the defaults.</span>
+<span class="sd">    :return: dict, params after manipulation,</span>
+<span class="sd">    &quot;&quot;&quot;</span>
+    <span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">val</span> <span class="ow">in</span> <span class="n">default_params</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
+        <span class="k">if</span> <span class="n">key</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">params</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span> <span class="ow">or</span> <span class="n">params</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
+            <span class="n">params</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">val</span>
+    <span class="k">return</span> <span class="n">params</span>
 </pre></div>
 </pre></div>
 
 
            </div>
            </div>
@@ -556,4 +583,4 @@
   </script> 
   </script> 
 
 
 </body>
 </body>
-</html>
+</html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.training.utils.version_utils &mdash; SuperGradients 3.0.3 documentation</title>
  7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
  10. <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
  11. <!--[if lt IE 9]>
  12. <script src="../../../../_static/js/html5shiv.min.js"></script>
  13. <![endif]-->
  14. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
  15. <script src="../../../../_static/jquery.js"></script>
  16. <script src="../../../../_static/underscore.js"></script>
  17. <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
  18. <script src="../../../../_static/doctools.js"></script>
  19. <script src="../../../../_static/sphinx_highlight.js"></script>
  20. <script src="../../../../_static/js/theme.js"></script>
  21. <link rel="index" title="Index" href="../../../../genindex.html" />
  22. <link rel="search" title="Search" href="../../../../search.html" />
  23. </head>
  24. <body class="wy-body-for-nav">
  25. <div class="wy-grid-for-nav">
  26. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  27. <div class="wy-side-scroll">
  28. <div class="wy-side-nav-search" >
  29. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
  30. </a>
  31. <div role="search">
  32. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
  33. <input type="text" name="q" placeholder="Search docs" />
  34. <input type="hidden" name="check_keywords" value="yes" />
  35. <input type="hidden" name="area" value="default" />
  36. </form>
  37. </div>
  38. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  39. <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
  40. <ul>
  41. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
  42. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
  45. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
  46. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
  47. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
  57. </ul>
  58. <p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
  59. <ul>
  60. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
  61. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
  62. </ul>
  63. </div>
  64. </div>
  65. </nav>
  66. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  67. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  68. <a href="../../../../index.html">SuperGradients</a>
  69. </nav>
  70. <div class="wy-nav-content">
  71. <div class="rst-content">
  72. <div role="navigation" aria-label="Page navigation">
  73. <ul class="wy-breadcrumbs">
  74. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  75. <li><a href="../../../index.html">Module code</a> &raquo;</li>
  76. <li>super_gradients.training.utils.version_utils</li>
  77. <li class="wy-breadcrumbs-aside">
  78. </li>
  79. </ul>
  80. <hr/>
  81. </div>
  82. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  83. <div itemprop="articleBody">
  84. <h1>Source code for super_gradients.training.utils.version_utils</h1><div class="highlight"><pre>
  85. <span></span><span class="kn">import</span> <span class="nn">torch</span>
  86. <span class="n">_TORCH_VERSION_MAJOR_MINOR_</span> <span class="o">=</span> <span class="nb">tuple</span><span class="p">(</span><span class="nb">map</span><span class="p">(</span><span class="nb">int</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">version</span><span class="o">.</span><span class="n">__version__</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">&quot;.&quot;</span><span class="p">)[:</span><span class="mi">2</span><span class="p">]))</span>
  87. <span class="n">__all__</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;torch_version_is_greater_or_equal&quot;</span><span class="p">]</span>
  88. <div class="viewcode-block" id="torch_version_is_greater_or_equal"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.utils.torch_version_is_greater_or_equal">[docs]</a><span class="k">def</span> <span class="nf">torch_version_is_greater_or_equal</span><span class="p">(</span><span class="n">major</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">minor</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
  89. <span class="n">version</span> <span class="o">=</span> <span class="p">(</span><span class="n">major</span><span class="p">,</span> <span class="n">minor</span><span class="p">)</span>
  90. <span class="k">return</span> <span class="n">_TORCH_VERSION_MAJOR_MINOR_</span> <span class="o">&gt;=</span> <span class="n">version</span></div>
  91. </pre></div>
  92. </div>
  93. </div>
  94. <footer>
  95. <hr/>
  96. <div role="contentinfo">
  97. <p>&#169; Copyright 2021, SuperGradients team.</p>
  98. </div>
  99. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  100. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  101. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  102. </footer>
  103. </div>
  104. </div>
  105. </section>
  106. </div>
  107. <script>
  108. jQuery(function () {
  109. SphinxRtdTheme.Navigation.enable(true);
  110. });
  111. </script>
  112. </body>
  113. </html>
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.training.utils.weight_averaging_utils &mdash; SuperGradients 1.0 documentation</title>
  7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
  10. <!--[if lt IE 9]>
  11. <script src="../../../../_static/js/html5shiv.min.js"></script>
  12. <![endif]-->
  13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
  14. <script src="../../../../_static/jquery.js"></script>
  15. <script src="../../../../_static/underscore.js"></script>
  16. <script src="../../../../_static/doctools.js"></script>
  17. <script src="../../../../_static/js/theme.js"></script>
  18. <link rel="index" title="Index" href="../../../../genindex.html" />
  19. <link rel="search" title="Search" href="../../../../search.html" />
  20. </head>
  21. <body class="wy-body-for-nav">
  22. <div class="wy-grid-for-nav">
  23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  24. <div class="wy-side-scroll">
  25. <div class="wy-side-nav-search" >
  26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
  27. </a>
  28. <div role="search">
  29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
  30. <input type="text" name="q" placeholder="Search docs" />
  31. <input type="hidden" name="check_keywords" value="yes" />
  32. <input type="hidden" name="area" value="default" />
  33. </form>
  34. </div>
  35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
  37. <ul>
  38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
  39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
  40. </ul>
  41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
  42. <ul>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
  45. </ul>
  46. <p class="caption"><span class="caption-text">User Guide</span></p>
  47. <ul>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
  57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
  58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
  59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
  60. </ul>
  61. </div>
  62. </div>
  63. </nav>
  64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  66. <a href="../../../../index.html">SuperGradients</a>
  67. </nav>
  68. <div class="wy-nav-content">
  69. <div class="rst-content">
  70. <div role="navigation" aria-label="Page navigation">
  71. <ul class="wy-breadcrumbs">
  72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
  74. <li>super_gradients.training.utils.weight_averaging_utils</li>
  75. <li class="wy-breadcrumbs-aside">
  76. </li>
  77. </ul>
  78. <hr/>
  79. </div>
  80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  81. <div itemprop="articleBody">
  82. <h1>Source code for super_gradients.training.utils.weight_averaging_utils</h1><div class="highlight"><pre>
  83. <span></span><span class="kn">import</span> <span class="nn">os</span>
  84. <span class="kn">import</span> <span class="nn">torch</span>
  85. <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
  86. <span class="kn">import</span> <span class="nn">pkg_resources</span>
  87. <span class="kn">from</span> <span class="nn">super_gradients.training</span> <span class="kn">import</span> <span class="n">utils</span> <span class="k">as</span> <span class="n">core_utils</span>
  88. <span class="kn">from</span> <span class="nn">super_gradients.training.utils.utils</span> <span class="kn">import</span> <span class="n">move_state_dict_to_device</span>
  89. <div class="viewcode-block" id="ModelWeightAveraging"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.weight_averaging_utils.ModelWeightAveraging">[docs]</a><span class="k">class</span> <span class="nc">ModelWeightAveraging</span><span class="p">:</span>
  90. <span class="sd">&quot;&quot;&quot;</span>
  91. <span class="sd"> Utils class for managing the averaging of the best several snapshots into a single model.</span>
  92. <span class="sd"> A snapshot dictionary file and the average model will be saved / updated at every epoch and evaluated only when</span>
  93. <span class="sd"> training is completed. The snapshot file will only be deleted upon completing the training.</span>
  94. <span class="sd"> The snapshot dict will be managed on cpu.</span>
  95. <span class="sd"> &quot;&quot;&quot;</span>
  96. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">ckpt_dir</span><span class="p">,</span>
  97. <span class="n">greater_is_better</span><span class="p">,</span>
  98. <span class="n">source_ckpt_folder_name</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">metric_to_watch</span><span class="o">=</span><span class="s1">&#39;acc&#39;</span><span class="p">,</span>
  99. <span class="n">metric_idx</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">load_checkpoint</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
  100. <span class="n">number_of_models_to_average</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span>
  101. <span class="n">model_checkpoints_location</span><span class="o">=</span><span class="s1">&#39;local&#39;</span>
  102. <span class="p">):</span>
  103. <span class="sd">&quot;&quot;&quot;</span>
  104. <span class="sd"> Init the ModelWeightAveraging</span>
  105. <span class="sd"> :param checkpoint_dir: the directory where the checkpoints are saved</span>
  106. <span class="sd"> :param metric_to_watch: monitoring loss or acc, will be identical to that which determines best_model</span>
  107. <span class="sd"> :param metric_idx:</span>
  108. <span class="sd"> :param load_checkpoint: whether to load pre-existing snapshot dict.</span>
  109. <span class="sd"> :param number_of_models_to_average: number of models to average</span>
  110. <span class="sd"> &quot;&quot;&quot;</span>
  111. <span class="k">if</span> <span class="n">source_ckpt_folder_name</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  112. <span class="n">source_ckpt_file</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">source_ckpt_folder_name</span><span class="p">,</span> <span class="s1">&#39;averaging_snapshots.pkl&#39;</span><span class="p">)</span>
  113. <span class="n">source_ckpt_file</span> <span class="o">=</span> <span class="n">pkg_resources</span><span class="o">.</span><span class="n">resource_filename</span><span class="p">(</span><span class="s1">&#39;checkpoints&#39;</span><span class="p">,</span> <span class="n">source_ckpt_file</span><span class="p">)</span>
  114. <span class="bp">self</span><span class="o">.</span><span class="n">averaging_snapshots_file</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">ckpt_dir</span><span class="p">,</span> <span class="s1">&#39;averaging_snapshots.pkl&#39;</span><span class="p">)</span>
  115. <span class="bp">self</span><span class="o">.</span><span class="n">number_of_models_to_average</span> <span class="o">=</span> <span class="n">number_of_models_to_average</span>
  116. <span class="bp">self</span><span class="o">.</span><span class="n">metric_to_watch</span> <span class="o">=</span> <span class="n">metric_to_watch</span>
  117. <span class="bp">self</span><span class="o">.</span><span class="n">metric_idx</span> <span class="o">=</span> <span class="n">metric_idx</span>
  118. <span class="bp">self</span><span class="o">.</span><span class="n">greater_is_better</span> <span class="o">=</span> <span class="n">greater_is_better</span>
  119. <span class="c1"># if continuing training, copy previous snapshot dict if exist</span>
  120. <span class="k">if</span> <span class="n">load_checkpoint</span> <span class="ow">and</span> <span class="n">source_ckpt_folder_name</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">isfile</span><span class="p">(</span><span class="n">source_ckpt_file</span><span class="p">):</span>
  121. <span class="n">averaging_snapshots_dict</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">load_checkpoint</span><span class="p">(</span><span class="n">ckpt_destination_dir</span><span class="o">=</span><span class="n">ckpt_dir</span><span class="p">,</span>
  122. <span class="n">source_ckpt_folder_name</span><span class="o">=</span><span class="n">source_ckpt_folder_name</span><span class="p">,</span>
  123. <span class="n">ckpt_filename</span><span class="o">=</span><span class="s2">&quot;averaging_snapshots.pkl&quot;</span><span class="p">,</span>
  124. <span class="n">load_weights_only</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
  125. <span class="n">model_checkpoints_location</span><span class="o">=</span><span class="n">model_checkpoints_location</span><span class="p">,</span>
  126. <span class="n">overwrite_local_ckpt</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  127. <span class="k">else</span><span class="p">:</span>
  128. <span class="n">averaging_snapshots_dict</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;snapshot&#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">i</span><span class="p">):</span> <span class="kc">None</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">number_of_models_to_average</span><span class="p">)}</span>
  129. <span class="c1"># if metric to watch is acc, hold a zero array, if loss hold inf array</span>
  130. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">greater_is_better</span><span class="p">:</span>
  131. <span class="n">averaging_snapshots_dict</span><span class="p">[</span><span class="s1">&#39;snapshots_metric&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span> <span class="o">*</span> <span class="n">np</span><span class="o">.</span><span class="n">inf</span> <span class="o">*</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">number_of_models_to_average</span><span class="p">)</span>
  132. <span class="k">else</span><span class="p">:</span>
  133. <span class="n">averaging_snapshots_dict</span><span class="p">[</span><span class="s1">&#39;snapshots_metric&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">inf</span> <span class="o">*</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">number_of_models_to_average</span><span class="p">)</span>
  134. <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">averaging_snapshots_dict</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">averaging_snapshots_file</span><span class="p">)</span>
  135. <div class="viewcode-block" id="ModelWeightAveraging.update_snapshots_dict"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.weight_averaging_utils.ModelWeightAveraging.update_snapshots_dict">[docs]</a> <span class="k">def</span> <span class="nf">update_snapshots_dict</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">validation_results_tuple</span><span class="p">):</span>
  136. <span class="sd">&quot;&quot;&quot;</span>
  137. <span class="sd"> Update the snapshot dict and returns the updated average model for saving</span>
  138. <span class="sd"> :param model: the latest model</span>
  139. <span class="sd"> :param validation_results_tuple: performance of the latest model</span>
  140. <span class="sd"> &quot;&quot;&quot;</span>
  141. <span class="n">averaging_snapshots_dict</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_averaging_snapshots_dict</span><span class="p">()</span>
  142. <span class="c1"># IF CURRENT MODEL IS BETTER, TAKING HIS PLACE IN ACC LIST AND OVERWRITE THE NEW AVERAGE</span>
  143. <span class="n">require_update</span><span class="p">,</span> <span class="n">update_ind</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_is_better</span><span class="p">(</span><span class="n">averaging_snapshots_dict</span><span class="p">,</span> <span class="n">validation_results_tuple</span><span class="p">)</span>
  144. <span class="k">if</span> <span class="n">require_update</span><span class="p">:</span>
  145. <span class="c1"># moving state dict to cpu</span>
  146. <span class="n">new_sd</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">state_dict</span><span class="p">()</span>
  147. <span class="n">new_sd</span> <span class="o">=</span> <span class="n">move_state_dict_to_device</span><span class="p">(</span><span class="n">new_sd</span><span class="p">,</span> <span class="s1">&#39;cpu&#39;</span><span class="p">)</span>
  148. <span class="n">averaging_snapshots_dict</span><span class="p">[</span><span class="s1">&#39;snapshot&#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">update_ind</span><span class="p">)]</span> <span class="o">=</span> <span class="n">new_sd</span>
  149. <span class="n">averaging_snapshots_dict</span><span class="p">[</span><span class="s1">&#39;snapshots_metric&#39;</span><span class="p">][</span><span class="n">update_ind</span><span class="p">]</span> <span class="o">=</span> <span class="n">validation_results_tuple</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">metric_idx</span><span class="p">]</span>
  150. <span class="k">return</span> <span class="n">averaging_snapshots_dict</span></div>
  151. <div class="viewcode-block" id="ModelWeightAveraging.get_average_model"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.weight_averaging_utils.ModelWeightAveraging.get_average_model">[docs]</a> <span class="k">def</span> <span class="nf">get_average_model</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">validation_results_tuple</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  152. <span class="sd">&quot;&quot;&quot;</span>
  153. <span class="sd"> Returns the averaged model</span>
  154. <span class="sd"> :param model: will be used to determine arch</span>
  155. <span class="sd"> :param validation_results_tuple: if provided, will update the average model before returning</span>
  156. <span class="sd"> :param target_device: if provided, return sd on target device</span>
  157. <span class="sd"> &quot;&quot;&quot;</span>
  158. <span class="c1"># If validation tuple is provided, update the average model</span>
  159. <span class="k">if</span> <span class="n">validation_results_tuple</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  160. <span class="n">averaging_snapshots_dict</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">update_snapshots_dict</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">validation_results_tuple</span><span class="p">)</span>
  161. <span class="k">else</span><span class="p">:</span>
  162. <span class="n">averaging_snapshots_dict</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_averaging_snapshots_dict</span><span class="p">()</span>
  163. <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">averaging_snapshots_dict</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">averaging_snapshots_file</span><span class="p">)</span>
  164. <span class="n">average_model_sd</span> <span class="o">=</span> <span class="n">averaging_snapshots_dict</span><span class="p">[</span><span class="s1">&#39;snapshot0&#39;</span><span class="p">]</span>
  165. <span class="k">for</span> <span class="n">n_model</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">number_of_models_to_average</span><span class="p">):</span>
  166. <span class="k">if</span> <span class="n">averaging_snapshots_dict</span><span class="p">[</span><span class="s1">&#39;snapshot&#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">n_model</span><span class="p">)]</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  167. <span class="n">net_sd</span> <span class="o">=</span> <span class="n">averaging_snapshots_dict</span><span class="p">[</span><span class="s1">&#39;snapshot&#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">n_model</span><span class="p">)]</span>
  168. <span class="c1"># USING MOVING AVERAGE</span>
  169. <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">average_model_sd</span><span class="p">:</span>
  170. <span class="n">average_model_sd</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">true_divide</span><span class="p">(</span>
  171. <span class="n">average_model_sd</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">*</span> <span class="n">n_model</span> <span class="o">+</span> <span class="n">net_sd</span><span class="p">[</span><span class="n">key</span><span class="p">],</span>
  172. <span class="p">(</span><span class="n">n_model</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span>
  173. <span class="k">return</span> <span class="n">average_model_sd</span></div>
  174. <div class="viewcode-block" id="ModelWeightAveraging.cleanup"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.weight_averaging_utils.ModelWeightAveraging.cleanup">[docs]</a> <span class="k">def</span> <span class="nf">cleanup</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  175. <span class="sd">&quot;&quot;&quot;</span>
  176. <span class="sd"> Delete snapshot file when reaching the last epoch</span>
  177. <span class="sd"> &quot;&quot;&quot;</span>
  178. <span class="n">os</span><span class="o">.</span><span class="n">remove</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">averaging_snapshots_file</span><span class="p">)</span></div>
  179. <span class="k">def</span> <span class="nf">_is_better</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">averaging_snapshots_dict</span><span class="p">,</span> <span class="n">validation_results_tuple</span><span class="p">):</span>
  180. <span class="sd">&quot;&quot;&quot;</span>
  181. <span class="sd"> Determines if the new model is better according to the specified metrics</span>
  182. <span class="sd"> :param averaging_snapshots_dict: snapshot dict</span>
  183. <span class="sd"> :param validation_results_tuple: latest model performance</span>
  184. <span class="sd"> &quot;&quot;&quot;</span>
  185. <span class="n">snapshot_metric_array</span> <span class="o">=</span> <span class="n">averaging_snapshots_dict</span><span class="p">[</span><span class="s1">&#39;snapshots_metric&#39;</span><span class="p">]</span>
  186. <span class="n">val</span> <span class="o">=</span> <span class="n">validation_results_tuple</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">metric_idx</span><span class="p">]</span>
  187. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">greater_is_better</span><span class="p">:</span>
  188. <span class="n">update_ind</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argmin</span><span class="p">(</span><span class="n">snapshot_metric_array</span><span class="p">)</span>
  189. <span class="k">else</span><span class="p">:</span>
  190. <span class="n">update_ind</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">snapshot_metric_array</span><span class="p">)</span>
  191. <span class="k">if</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">greater_is_better</span> <span class="ow">and</span> <span class="n">val</span> <span class="o">&gt;</span> <span class="n">snapshot_metric_array</span><span class="p">[</span><span class="n">update_ind</span><span class="p">])</span> <span class="ow">or</span> <span class="p">(</span>
  192. <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">greater_is_better</span> <span class="ow">and</span> <span class="n">val</span> <span class="o">&lt;</span> <span class="n">snapshot_metric_array</span><span class="p">[</span><span class="n">update_ind</span><span class="p">]):</span>
  193. <span class="k">return</span> <span class="kc">True</span><span class="p">,</span> <span class="n">update_ind</span>
  194. <span class="k">return</span> <span class="kc">False</span><span class="p">,</span> <span class="kc">None</span>
  195. <span class="k">def</span> <span class="nf">_get_averaging_snapshots_dict</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  196. <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">averaging_snapshots_file</span><span class="p">)</span></div>
  197. </pre></div>
  198. </div>
  199. </div>
  200. <footer>
  201. <hr/>
  202. <div role="contentinfo">
  203. <p>&#169; Copyright 2021, SuperGradients team.</p>
  204. </div>
  205. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  206. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  207. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  208. </footer>
  209. </div>
  210. </div>
  211. </section>
  212. </div>
  213. <script>
  214. jQuery(function () {
  215. SphinxRtdTheme.Navigation.enable(true);
  216. });
  217. </script>
  218. </body>
  219. </html>
Discard